scdesigner 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of scdesigner might be problematic.
- scdesigner/__init__.py +0 -0
- scdesigner/data/__init__.py +16 -0
- scdesigner/data/formula.py +137 -0
- scdesigner/data/group.py +123 -0
- scdesigner/data/sparse.py +39 -0
- scdesigner/diagnose/__init__.py +65 -0
- scdesigner/diagnose/aic_bic.py +119 -0
- scdesigner/diagnose/plot.py +242 -0
- scdesigner/estimators/__init__.py +27 -0
- scdesigner/estimators/bernoulli.py +85 -0
- scdesigner/estimators/gaussian.py +121 -0
- scdesigner/estimators/gaussian_copula_factory.py +152 -0
- scdesigner/estimators/glm_factory.py +75 -0
- scdesigner/estimators/negbin.py +129 -0
- scdesigner/estimators/pnmf.py +160 -0
- scdesigner/estimators/poisson.py +100 -0
- scdesigner/estimators/zero_inflated_negbin.py +195 -0
- scdesigner/estimators/zero_inflated_poisson.py +85 -0
- scdesigner/format/__init__.py +4 -0
- scdesigner/format/format.py +20 -0
- scdesigner/format/print.py +30 -0
- scdesigner/minimal/__init__.py +17 -0
- scdesigner/minimal/bernoulli.py +61 -0
- scdesigner/minimal/composite.py +119 -0
- scdesigner/minimal/copula.py +33 -0
- scdesigner/minimal/formula.py +23 -0
- scdesigner/minimal/gaussian.py +65 -0
- scdesigner/minimal/kwargs.py +24 -0
- scdesigner/minimal/loader.py +166 -0
- scdesigner/minimal/marginal.py +140 -0
- scdesigner/minimal/negbin.py +73 -0
- scdesigner/minimal/positive_nonnegative_matrix_factorization.py +231 -0
- scdesigner/minimal/scd3.py +95 -0
- scdesigner/minimal/scd3_instances.py +50 -0
- scdesigner/minimal/simulator.py +25 -0
- scdesigner/minimal/standard_covariance.py +124 -0
- scdesigner/minimal/transform.py +145 -0
- scdesigner/minimal/zero_inflated_negbin.py +86 -0
- scdesigner/predictors/__init__.py +15 -0
- scdesigner/predictors/bernoulli.py +9 -0
- scdesigner/predictors/gaussian.py +16 -0
- scdesigner/predictors/negbin.py +17 -0
- scdesigner/predictors/poisson.py +12 -0
- scdesigner/predictors/zero_inflated_negbin.py +18 -0
- scdesigner/predictors/zero_inflated_poisson.py +18 -0
- scdesigner/samplers/__init__.py +23 -0
- scdesigner/samplers/bernoulli.py +27 -0
- scdesigner/samplers/gaussian.py +25 -0
- scdesigner/samplers/glm_factory.py +41 -0
- scdesigner/samplers/negbin.py +25 -0
- scdesigner/samplers/poisson.py +25 -0
- scdesigner/samplers/zero_inflated_negbin.py +40 -0
- scdesigner/samplers/zero_inflated_poisson.py +16 -0
- scdesigner/simulators/__init__.py +31 -0
- scdesigner/simulators/composite_regressor.py +72 -0
- scdesigner/simulators/glm_simulator.py +167 -0
- scdesigner/simulators/pnmf_regression.py +61 -0
- scdesigner/transform/__init__.py +7 -0
- scdesigner/transform/amplify.py +14 -0
- scdesigner/transform/mask.py +33 -0
- scdesigner/transform/nullify.py +25 -0
- scdesigner/transform/split.py +23 -0
- scdesigner/transform/substitute.py +14 -0
- scdesigner-0.0.1.dist-info/METADATA +23 -0
- scdesigner-0.0.1.dist-info/RECORD +66 -0
- scdesigner-0.0.1.dist-info/WHEEL +4 -0
scdesigner/diagnose/plot.py
@@ -0,0 +1,242 @@
+import scanpy as sc
+import numpy as np
+import pandas as pd
+import altair as alt
+import matplotlib.pyplot as plt
+
+
+def adata_df(adata):
+    return (
+        pd.DataFrame(adata.X, columns=adata.var_names)
+        .melt(id_vars=[], value_vars=adata.var_names)
+        .reset_index(drop=True)
+    )
+
+
+def merge_samples(adata, sim):
+    source = adata_df(adata)
+    simulated = adata_df(sim)
+    return pd.concat(
+        {"real": source, "simulated": simulated}, names=["source"]
+    ).reset_index(level="source")
+
+
+def plot_umap(
+    adata,
+    color=None,
+    shape=None,
+    facet=None,
+    opacity=0.6,
+    n_comps=20,
+    n_neighbors=15,
+    transform=lambda x: np.log1p(x),
+    **kwargs
+):
+    mapping = {"x": "UMAP1", "y": "UMAP2", "color": color, "shape": shape}
+    mapping = {k: v for k, v in mapping.items() if v is not None}
+
+    adata_ = adata.copy()
+    adata_.X = check_sparse(adata_.X)
+    Z = transform(adata_.X)
+    if Z.shape[1] == adata_.X.shape[1]:
+        adata_.X = transform(adata_.X)
+    else:
+        adata_ = adata_[:, : Z.shape[1]]
+        adata_.X = Z
+        adata_.var_names = [f"transform_{k}" for k in range(Z.shape[1])]
+
+    # umap on the top PCA dimensions
+    sc.pp.pca(adata_, n_comps=n_comps)
+    sc.pp.neighbors(adata_, n_neighbors=n_neighbors, n_pcs=n_comps)
+    sc.tl.umap(adata_, **kwargs)
+
+    # get umap embeddings
+    umap_df = pd.DataFrame(adata_.obsm["X_umap"], columns=["UMAP1", "UMAP2"])
+    umap_df = pd.concat([umap_df, adata_.obs.reset_index(drop=True)], axis=1)
+
+    # encode and visualize
+    chart = alt.Chart(umap_df).mark_point(opacity=opacity).encode(**mapping)
+    if facet is not None:
+        chart = chart.facet(column=alt.Facet(facet))
+    return chart
+
+
+def plot_pca(
+    adata,
+    color=None,
+    shape=None,
+    facet=None,
+    opacity=0.6,
+    plot_dims=[0, 1],
+    transform=lambda x: np.log1p(x),
+    **kwargs
+):
+    mapping = {"x": "PCA1", "y": "PCA2", "color": color, "shape": shape}
+    mapping = {k: v for k, v in mapping.items() if v is not None}
+
+    adata_ = adata.copy()
+    adata_.X = check_sparse(adata_.X)
+    adata_.X = transform(adata_.X)
+
+    # get PCA scores
+    sc.pp.pca(adata_, **kwargs)
+    pca_df = pd.DataFrame(adata_.obsm["X_pca"][:, plot_dims], columns=["PCA1", "PCA2"])
+    pca_df = pd.concat([pca_df, adata_.obs.reset_index(drop=True)], axis=1)
+
+    # plot
+    chart = alt.Chart(pca_df).mark_point(opacity=opacity).encode(**mapping)
+    if facet is not None:
+        chart = chart.facet(column=alt.Facet(facet))
+    return chart
+
+
+def compare_summary(real, simulated, summary_fun):
+    df = pd.DataFrame({"real": summary_fun(real), "simulated": summary_fun(simulated)})
+
+    identity = pd.DataFrame(
+        {
+            "real": [df["real"].min(), df["real"].max()],
+            "simulated": [df["real"].min(), df["real"].max()],
+        }
+    )
+    return alt.Chart(identity).mark_line(color="#dedede").encode(
+        x="real", y="simulated"
+    ) + alt.Chart(df).mark_circle().encode(x="real", y="simulated")
+
+
+def check_sparse(X):
+    if not isinstance(X, np.ndarray):
+        X = X.todense()
+    return X
+
+
+def compare_means(real, simulated, transform=lambda x: x):
+    real_, simulated_ = prepare_dense(real, simulated)
+    summary = lambda a: np.asarray(transform(a.X).mean(axis=0)).flatten()
+    return compare_summary(real_, simulated_, summary)
+
+
+def prepare_dense(real, simulated):
+    real_ = real.copy()
+    simulated_ = simulated.copy()
+    real_.X = check_sparse(real_.X)
+    simulated_.X = check_sparse(simulated_.X)
+    return real_, simulated_
+
+
+def compare_variances(real, simulated, transform=lambda x: x):
+    real_, simulated_ = prepare_dense(real, simulated)
+    summary = lambda a: np.asarray(np.var(transform(a.X), axis=0)).flatten()
+    return compare_summary(real_, simulated_, summary)
+
+
+def compare_standard_deviation(real, simulated, transform=lambda x: x):
+    real_, simulated_ = prepare_dense(real, simulated)
+    summary = lambda a: np.asarray(np.std(transform(a.X), axis=0)).flatten()
+    return compare_summary(real_, simulated_, summary)
+
+
+def concat_real_sim(real, simulated):
+    real_, simulated_ = prepare_dense(real, simulated)
+    real_.obs["source"] = "real"
+    simulated_.obs["source"] = "simulated"
+    return real_.concatenate(simulated_, join="outer", batch_key=None)
+
+
+def compare_umap(real, simulated, transform=lambda x: x, **kwargs):
+    adata = concat_real_sim(real, simulated)
+    return plot_umap(adata, facet="source", transform=transform, **kwargs)
+
+
+def compare_pca(real, simulated, transform=lambda x: x, **kwargs):
+    adata = concat_real_sim(real, simulated)
+    return plot_pca(adata, facet="source", transform=transform, **kwargs)
+
+
+def plot_hist(sim_data, real_data, idx):
+    sim = sim_data[:, idx]
+    real = real_data[:, idx]
+    b = np.linspace(min(min(sim), min(real)), max(max(sim), max(real)), 50)
+
+    plt.hist([real, sim], b, label=["Real", "Simulated"], histtype="bar")
+    plt.xlabel("x")
+    plt.ylabel("Count")
+    plt.legend()
+    plt.show()
+
+
+def compare_ecdf(adata, sim, var_names=None, max_plot=10, n_cols=5, **kwargs):
+    if var_names is None:
+        var_names = adata.var_names[:max_plot]
+
+    combined = merge_samples(adata[:, var_names], sim.sample()[:, var_names])
+    alt.data_transformers.enable("vegafusion")
+
+    plot = (
+        alt.Chart(combined)
+        .transform_window(
+            ecdf="cume_dist()", sort=[{"field": "value"}], groupby=["variable"]
+        )
+        .mark_line(
+            interpolate="step-after",
+        )
+        .encode(
+            x="value:Q",
+            y="ecdf:Q",
+            color="source:N",
+            facet=alt.Facet(
+                "variable", sort=alt.EncodingSortField("value"), columns=n_cols
+            ),
+        )
+        .properties(**kwargs)
+    )
+    plot.show()
+    return plot, combined
+
+
+def compare_boxplot(adata, sim, var_names=None, max_plot=20, **kwargs):
+    if var_names is None:
+        var_names = adata.var_names[:max_plot]
+
+    combined = merge_samples(adata[:, var_names], sim.sample()[:, var_names])
+    alt.data_transformers.enable("vegafusion")
+
+    plot = (
+        alt.Chart(combined)
+        .mark_boxplot(extent="min-max")
+        .encode(
+            x=alt.X("value:Q").scale(zero=False),
+            y=alt.Y(
+                "variable:N",
+                sort=alt.EncodingSortField("mid_box_value", order="descending"),
+            ),
+            facet="source:N",
+        )
+        .properties(**kwargs)
+    )
+    plot.show()
+    return plot, combined
+
+
+def compare_histogram(adata, sim, var_names=None, max_plot=20, n_cols=5, **kwargs):
+    if var_names is None:
+        var_names = adata.var_names[:max_plot]
+
+    combined = merge_samples(adata[:, var_names], sim.sample()[:, var_names])
+    alt.data_transformers.enable("vegafusion")
+
+    plot = (
+        alt.Chart(combined)
+        .mark_bar(opacity=0.7)
+        .encode(
+            x=alt.X("value:Q").bin(maxbins=20),
+            y=alt.Y("count()").stack(None),
+            color="source:N",
+            facet=alt.Facet(
+                "variable", sort=alt.EncodingSortField("bin_maxbins_20_value")
+            ),
+        )
+        .properties(**kwargs)
+    )
+    plot.show()
+    return plot, combined
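A minimal usage sketch for these diagnostics (not from the package), assuming the functions are importable from `scdesigner.diagnose` and that both inputs are dense AnnData objects with matching `var_names`:

import anndata as ad
import numpy as np
from scdesigner.diagnose import compare_means, compare_umap

# toy "real" and "simulated" matrices with identical Poisson margins
rng = np.random.default_rng(0)
real = ad.AnnData(rng.poisson(5.0, size=(200, 50)).astype(float))
simulated = ad.AnnData(rng.poisson(5.0, size=(200, 50)).astype(float))

# gene-wise means on the log1p scale, scattered against the identity line
compare_means(real, simulated, transform=np.log1p).save("means.html")

# UMAP of the concatenated cells, faceted into real vs. simulated panels
compare_umap(real, simulated, transform=np.log1p).save("umap.html")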
scdesigner/estimators/__init__.py
@@ -0,0 +1,27 @@
+from .negbin import negbin_regression, negbin_copula
+from .gaussian_copula_factory import group_indices
+from .poisson import poisson_regression, poisson_copula
+from .bernoulli import bernoulli_regression, bernoulli_copula
+from .gaussian import gaussian_regression, gaussian_copula
+from .zero_inflated_negbin import (
+    zero_inflated_negbin_regression,
+    zero_inflated_negbin_copula,
+)
+from .zero_inflated_poisson import zero_inflated_poisson_regression
+from .glm_factory import multiple_formula_regression_factory
+
+__all__ = [
+    "bernoulli_regression",
+    "bernoulli_copula",
+    "negbin_copula",
+    "negbin_regression",
+    "gaussian_regression",
+    "gaussian_copula",
+    "group_indices",
+    "poisson_copula",
+    "poisson_regression",
+    "zero_inflated_negbin_copula",
+    "zero_inflated_negbin_regression",
+    "zero_inflated_poisson_regression",
+    "multiple_formula_regression_factory",
+]
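The module is re-exported flat: each marginal family ships a `*_regression` entry point that fits per-gene GLM parameters and a `*_copula` entry point that additionally estimates the Gaussian copula covariance, so downstream code can import either directly:

from scdesigner.estimators import negbin_regression, negbin_copula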
scdesigner/estimators/bernoulli.py
@@ -0,0 +1,85 @@
+import pandas as pd
+from . import gaussian_copula_factory as gcf
+from . import glm_factory as factory
+from .. import data
+from .. import format
+from . import poisson as poi
+from anndata import AnnData
+from scipy.stats import bernoulli
+import numpy as np
+import torch
+from typing import Union
+
+###############################################################################
+## Regression functions that operate on numpy arrays
+###############################################################################
+
+
+def bernoulli_regression_likelihood(params, X_dict, y):
+    # get appropriate parameter shape
+    n_features = X_dict["mean"].shape[1]
+    n_outcomes = y.shape[1]
+
+    # compute the negative log likelihood
+    coef_mean = params.reshape(n_features, n_outcomes)
+    theta = torch.sigmoid(X_dict["mean"] @ coef_mean)
+    log_likelihood = y * torch.log(theta) + (1 - y) * torch.log(1 - theta)
+    return -torch.sum(log_likelihood)
+
+
+def bernoulli_initializer(X_dict, y, device):
+    n_features = X_dict["mean"].shape[1]
+    n_outcomes = y.shape[1]
+    return torch.zeros(n_features * n_outcomes, requires_grad=True, device=device)
+
+
+def bernoulli_postprocessor(params, X_dict, y):
+    coef_mean = format.to_np(params).reshape(X_dict["mean"].shape[1], y.shape[1])
+    return {"coef_mean": coef_mean}
+
+
+bernoulli_regression_array = factory.multiple_formula_regression_factory(
+    bernoulli_regression_likelihood, bernoulli_initializer, bernoulli_postprocessor
+)
+
+###############################################################################
+## Regression functions that operate on AnnData objects
+###############################################################################
+
+
+def bernoulli_regression(
+    adata: AnnData, formula: Union[str, dict], chunk_size: int = int(1e4), batch_size=512, **kwargs
+) -> dict:
+    formula = data.standardize_formula(formula, allowed_keys=['mean'])
+    loaders = data.multiple_formula_loader(
+        adata, formula, chunk_size=chunk_size, batch_size=batch_size
+    )
+    parameters = bernoulli_regression_array(loaders, **kwargs)
+    return poi.format_poisson_parameters(
+        parameters, list(adata.var_names), list(loaders["mean"].dataset.x_names)
+    )
+
+
+###############################################################################
+## Copula versions for bernoulli regression
+###############################################################################
+
+
+def bernoulli_uniformizer(parameters, X_dict, y, epsilon=1e-3, random_seed=42):
+    np.random.seed(random_seed)
+    theta = torch.sigmoid(torch.from_numpy(X_dict["mean"] @ parameters["coef_mean"])).numpy()
+    u1 = bernoulli(theta).cdf(y)
+    u2 = np.where(y > 0, bernoulli(theta).cdf(y - 1), 0)
+    v = np.random.uniform(size=y.shape)
+    return np.clip(v * u1 + (1 - v) * u2, epsilon, 1 - epsilon)
+
+
+def format_bernoulli_parameters_with_loaders(parameters: dict, var_names: list, dls: dict) -> dict:
+    coef_mean_index = dls["mean"].dataset.x_names
+
+    parameters["coef_mean"] = pd.DataFrame(
+        parameters["coef_mean"], columns=var_names, index=coef_mean_index
+    )
+    return parameters
+
+
+bernoulli_copula = gcf.gaussian_copula_factory(
+    gcf.gaussian_copula_array_factory(bernoulli_regression_array, bernoulli_uniformizer),
+    format_bernoulli_parameters_with_loaders,
+    param_name=['mean']
+)
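`bernoulli_uniformizer` is the randomized distributional transform for discrete margins: with V ~ Uniform(0, 1), U = V * F(y) + (1 - V) * F(y - 1) is uniform whenever y ~ F. A self-contained sanity check of that identity (not from the package):

import numpy as np
from scipy.stats import bernoulli, kstest

rng = np.random.default_rng(0)
theta, n = 0.3, 100_000
y = bernoulli(theta).rvs(n, random_state=rng)
v = rng.uniform(size=n)
# U = V * F(y) + (1 - V) * F(y - 1), with F(y - 1) taken as 0 at y = 0
u = v * bernoulli(theta).cdf(y) + (1 - v) * np.where(y > 0, bernoulli(theta).cdf(y - 1), 0.0)
print(kstest(u, "uniform"))  # large p-value: u is indistinguishable from Uniform(0, 1)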
scdesigner/estimators/gaussian.py
@@ -0,0 +1,121 @@
+from . import gaussian_copula_factory as gcf
+from . import glm_factory as factory
+from .. import format
+from .. import data
+from anndata import AnnData
+from scipy.stats import norm
+import numpy as np
+import pandas as pd
+import torch
+from typing import Union
+
+###############################################################################
+## Regression functions that operate on numpy arrays
+###############################################################################
+
+
+# Gaussian regression likelihood: regression onto mean and sdev
+def gaussian_regression_likelihood(params, X_dict, y):
+    n_mean_features = X_dict["mean"].shape[1]
+    n_sdev_features = X_dict["sdev"].shape[1]
+    n_outcomes = y.shape[1]
+
+    coef_mean = params[: n_mean_features * n_outcomes].reshape(n_mean_features, n_outcomes)
+    coef_sdev = params[n_mean_features * n_outcomes :].reshape(n_sdev_features, n_outcomes)
+    mu = X_dict["mean"] @ coef_mean
+    sigma = torch.exp(X_dict["sdev"] @ coef_sdev)
+
+    # Gaussian log likelihood; the loss to minimize is its negation
+    log_likelihood = -0.5 * (torch.log(2 * torch.pi * sigma ** 2) + ((y - mu) ** 2) / (sigma ** 2))
+    return -torch.sum(log_likelihood)
+
+
+def gaussian_initializer(x_dict, y, device):
+    n_mean_features = x_dict["mean"].shape[1]
+    n_outcomes = y.shape[1]
+    n_sdev_features = x_dict["sdev"].shape[1]
+    return torch.zeros(
+        n_mean_features * n_outcomes + n_sdev_features * n_outcomes,
+        requires_grad=True, device=device
+    )
+
+
+def gaussian_postprocessor(params, x_dict, y):
+    n_mean_features = x_dict["mean"].shape[1]
+    n_outcomes = y.shape[1]
+    n_sdev_features = x_dict["sdev"].shape[1]
+    coef_mean = format.to_np(params[:n_mean_features * n_outcomes]).reshape(n_mean_features, n_outcomes)
+    coef_sdev = format.to_np(params[n_mean_features * n_outcomes:]).reshape(n_sdev_features, n_outcomes)
+    return {"coef_mean": coef_mean, "coef_sdev": coef_sdev}
+
+
+gaussian_regression_array = factory.multiple_formula_regression_factory(
+    gaussian_regression_likelihood, gaussian_initializer, gaussian_postprocessor
+)
+
+
+###############################################################################
+## Regression functions that operate on AnnData objects
+###############################################################################
+
+
+def format_gaussian_parameters(
+    parameters: dict, var_names: list, mean_coef_index: list, sdev_coef_index: list
+) -> dict:
+    parameters["coef_mean"] = pd.DataFrame(
+        parameters["coef_mean"], columns=var_names, index=mean_coef_index
+    )
+    parameters["coef_sdev"] = pd.DataFrame(
+        parameters["coef_sdev"], columns=var_names, index=sdev_coef_index
+    )
+    return parameters
+
+
+def format_gaussian_parameters_with_loaders(
+    parameters: dict, var_names: list, dls: dict
+) -> dict:
+    mean_coef_index = dls["mean"].dataset.x_names
+    sdev_coef_index = dls["sdev"].dataset.x_names
+    return format_gaussian_parameters(
+        parameters, var_names, mean_coef_index, sdev_coef_index
+    )
+
+
+def gaussian_regression(
+    adata: AnnData, formula: Union[str, dict], chunk_size: int = int(1e4),
+    batch_size=512, **kwargs
+) -> dict:
+    formula = data.standardize_formula(formula, allowed_keys=['mean', 'sdev'])
+    loaders = data.multiple_formula_loader(
+        adata, formula, chunk_size=chunk_size, batch_size=batch_size
+    )
+    parameters = gaussian_regression_array(loaders, **kwargs)
+    return format_gaussian_parameters(
+        parameters, list(adata.var_names), loaders["mean"].dataset.x_names,
+        loaders["sdev"].dataset.x_names
+    )
+
+
+###############################################################################
+## Copula versions for gaussian regression
+###############################################################################
+
+
+def gaussian_uniformizer(parameters, X_dict, y, epsilon=1e-3):
+    mu = X_dict["mean"] @ parameters["coef_mean"]
+    sigma = np.exp(X_dict["sdev"] @ parameters["coef_sdev"])
+    u = norm.cdf(y, loc=mu, scale=sigma)
+    u = np.clip(u, epsilon, 1 - epsilon)
+    return u
+
+
+gaussian_copula_array = gcf.gaussian_copula_array_factory(
+    gaussian_regression_array, gaussian_uniformizer
+)
+
+gaussian_copula = gcf.gaussian_copula_factory(
+    gaussian_copula_array, format_gaussian_parameters_with_loaders,
+    param_name=['mean', 'sdev']
+)
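A standalone sanity check of the Gaussian objective above, bypassing the dataloader machinery and calling the likelihood pieces directly; with an intercept-only design, the minimizer of the negative log-likelihood should land near the sample mean and the log of the sample standard deviation (a sketch, with loose tolerances):

import numpy as np
import torch

rng = np.random.default_rng(0)
y = torch.tensor(rng.normal(2.0, 0.5, size=(5000, 1)), dtype=torch.float32)
X = {"mean": torch.ones(5000, 1), "sdev": torch.ones(5000, 1)}

params = gaussian_initializer(X, y, "cpu")
optimizer = torch.optim.Adam([params], lr=0.05)
for _ in range(500):
    optimizer.zero_grad()
    loss = gaussian_regression_likelihood(params, X, y)
    loss.backward()
    optimizer.step()
print(params.detach().numpy())  # approximately [2.0, log(0.5)] = [2.0, -0.69]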
scdesigner/estimators/gaussian_copula_factory.py
@@ -0,0 +1,152 @@
+from ..data import stack_collate, multiple_formula_group_loader
+from .. import data
+from anndata import AnnData
+from collections.abc import Callable
+from typing import Union
+from scipy.stats import norm
+from torch.utils.data import DataLoader
+import numpy as np
+import pandas as pd
+
+###############################################################################
+## General copula factory functions
+###############################################################################
+
+
+def gaussian_copula_array_factory(marginal_model: Callable, uniformizer: Callable):
+    def copula_fun(loaders: dict[str, DataLoader], lr: float = 0.1, epochs: int = 40, **kwargs):
+        # for the marginal model, ignore the groupings:
+        # strip all dataloaders and create a dictionary to pass to marginal_model
+        formula_loaders = {}
+        for key in loaders.keys():
+            formula_loaders[key] = strip_dataloader(
+                loaders[key], pop="Stack" in type(loaders[key].dataset).__name__
+            )
+
+        # call marginal_model with the dictionary of stripped dataloaders
+        parameters = marginal_model(formula_loaders, lr=lr, epochs=epochs, **kwargs)
+
+        # estimate covariance, allowing for different groups
+        parameters["covariance"] = copula_covariance(parameters, loaders, uniformizer)
+        return parameters
+
+    return copula_fun
+
+
+def gaussian_copula_factory(copula_array_fun: Callable,
+                            parameter_formatter: Callable,
+                            param_name: list = None):
+    def copula_fun(
+        adata: AnnData,
+        formula: Union[str, dict] = "~ 1",
+        grouping_var: str = None,
+        chunk_size: int = int(1e4),
+        batch_size: int = 512,
+        **kwargs
+    ) -> dict:
+        if param_name is not None:
+            formula = data.standardize_formula(formula, param_name)
+
+        dls = multiple_formula_group_loader(
+            adata,
+            formula,
+            grouping_var,
+            chunk_size=chunk_size,
+            batch_size=batch_size,
+        )  # returns a dictionary of dataloaders
+        parameters = copula_array_fun(dls, **kwargs)
+
+        # pass the full dls to parameter_formatter so it can extract what it needs
+        parameters = parameter_formatter(parameters, adata.var_names, dls)
+        parameters["covariance"] = format_copula_parameters(parameters, adata.var_names)
+        return parameters
+
+    return copula_fun
+
+
+def copula_covariance(parameters: dict, loaders: dict[str, DataLoader], uniformizer: Callable):
+    first_loader = next(iter(loaders.values()))
+    D = next(iter(first_loader))[1].shape[1]  # dimension of y
+    groups = first_loader.dataset.groups  # a list of strings of group names
+    sums = {g: np.zeros(D) for g in groups}
+    second_moments = {g: np.eye(D) for g in groups}
+    Ng = {g: 0 for g in groups}
+    keys = list(loaders.keys())
+    loaders = list(loaders.values())
+    num_keys = len(keys)
+
+    for batches in zip(*loaders):
+        x_batch_dict = {
+            keys[i]: batches[i][0].cpu().numpy() for i in range(num_keys)
+        }
+        y_batch = batches[0][1].cpu().numpy()
+        memberships = batches[0][2]  # should be identical for all keys
+
+        u = uniformizer(parameters, x_batch_dict, y_batch)
+        for g in groups:
+            ix = np.where(np.array(memberships) == g)
+            z = norm().ppf(u[ix])
+            second_moments[g] += z.T @ z
+            sums[g] += z.sum(axis=0)
+            Ng[g] += len(ix[0])
+
+    result = {}
+    for g in groups:
+        mean = sums[g] / Ng[g]
+        result[g] = second_moments[g] / Ng[g] - np.outer(mean, mean)
+
+    if len(groups) == 1:
+        return list(result.values())[0]
+    return result
+
+
+###############################################################################
+## Helpers to prepare and postprocess copula parameters
+###############################################################################
+
+
+def group_indices(grouping_var: str, obs: pd.DataFrame) -> dict:
+    """
+    Returns a dictionary of group indices for each group in the grouping variable.
+    """
+    if grouping_var is None:
+        grouping_var = "_copula_group"
+        if "_copula_group" not in obs.columns:
+            obs["_copula_group"] = pd.Categorical(["shared_group"] * len(obs))
+    result = {}
+
+    for group in list(obs[grouping_var].dtype.categories):
+        result[group] = np.where(obs[grouping_var].values == group)[0]
+    return result
+
+
+def clip(u: np.ndarray, min: float = 1e-5, max: float = 1 - 1e-5) -> np.ndarray:
+    u[u < min] = min
+    u[u > max] = max
+    return u
+
+
+def format_copula_parameters(parameters: dict, var_names: list):
+    covariance = parameters["covariance"]
+    if type(covariance) is not dict:
+        covariance = pd.DataFrame(
+            parameters["covariance"], columns=list(var_names), index=list(var_names)
+        )
+    else:
+        for group in covariance.keys():
+            covariance[group] = pd.DataFrame(
+                parameters["covariance"][group],
+                columns=list(var_names),
+                index=list(var_names),
+            )
+    return covariance
+
+
+def strip_dataloader(dataloader, pop=False):
+    return DataLoader(
+        dataset=dataloader.dataset,
+        batch_sampler=dataloader.batch_sampler,
+        collate_fn=stack_collate(pop=pop, groups=False),
+    )
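`copula_covariance` streams per-group first and second moments of the Gaussian scores z = norm.ppf(u). Apart from the `np.eye(D)` the accumulators start from (which appears to act as a small ridge on the diagonal), it agrees with the one-shot estimate, sketched here for a single group:

import numpy as np
from scipy.stats import norm

def copula_covariance_oneshot(u: np.ndarray) -> np.ndarray:
    # u: an (n, D) matrix of uniformized observations for one group
    z = norm.ppf(u)
    # E[z z^T] - E[z] E[z]^T, with the same 1/n normalization as the streaming version
    return np.cov(z, rowvar=False, bias=True)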
scdesigner/estimators/glm_factory.py
@@ -0,0 +1,75 @@
+from collections.abc import Callable
+from tqdm import tqdm
+from torch.utils.data import DataLoader
+import torch
+
+
+def glm_regression_factory(likelihood, initializer, postprocessor) -> Callable:
+    def estimator(
+        dataloader: DataLoader,
+        lr: float = 0.1,
+        epochs: int = 40,
+    ):
+        device = check_device()
+        x, y = next(iter(dataloader))
+        params = initializer(x, y, device)
+        optimizer = torch.optim.Adam([params], lr=lr)
+
+        for epoch in range(epochs):
+            for x_batch, y_batch in (pbar := tqdm(dataloader, desc=f"Epoch {epoch + 1}/{epochs}", leave=False)):
+                x_batch, y_batch = x_batch.to(device), y_batch.to(device)
+                optimizer.zero_grad()
+                loss = likelihood(params, x_batch, y_batch)
+                loss.backward()
+                optimizer.step()
+                pbar.set_postfix_str(f"loss: {loss.item()}")
+
+        return postprocessor(params, x.shape[1], y.shape[1])
+
+    return estimator
+
+
+def multiple_formula_regression_factory(likelihood, initializer, postprocessor) -> Callable:
+    def estimator(
+        dataloaders: dict[str, DataLoader],
+        lr: float = 0.1,
+        epochs: int = 40,
+    ):
+        device = check_device()
+        x_dict = {}
+        y_dict = {}
+        for key in dataloaders.keys():
+            x_dict[key], y_dict[key] = next(iter(dataloaders[key]))
+
+        # check if all ys are the same
+        y_ref = y_dict[list(dataloaders.keys())[0]]
+        for key in dataloaders.keys():
+            if not torch.equal(y_dict[key], y_ref):
+                raise ValueError(f"Ys are not the same for {key}")
+
+        params = initializer(x_dict, y_ref, device)  # x is a dictionary of tensors, y is a tensor
+        optimizer = torch.optim.Adam([params], lr=lr)
+
+        keys = list(dataloaders.keys())
+        loaders = list(dataloaders.values())
+        num_keys = len(keys)
+
+        for epoch in range(epochs):
+            for batches in (pbar := tqdm(zip(*loaders), desc=f"Epoch {epoch + 1}/{epochs}", leave=False)):
+                x_batch_dict = {
+                    keys[i]: batches[i][0].to(device) for i in range(num_keys)
+                }
+                y_batch = batches[0][1].to(device)
+                optimizer.zero_grad()
+                loss = likelihood(params, x_batch_dict, y_batch)
+                loss.backward()
+                optimizer.step()
+                pbar.set_postfix_str(f"loss: {loss.item()}")
+
+        return postprocessor(params, x_dict, y_ref)
+
+    return estimator
+
+
+def check_device():
+    return torch.device(
+        "cuda"
+        if torch.cuda.is_available()
+        else "mps" if torch.backends.mps.is_available() else "cpu"
+    )
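A minimal sketch of driving `multiple_formula_regression_factory` with hand-rolled dataloaders instead of the package's `data.multiple_formula_loader`; the Poisson likelihood, initializer, and pass-through postprocessor below are illustrative stand-ins, not package code:

import torch
from torch.utils.data import DataLoader, TensorDataset

def poisson_nll(params, X_dict, y):
    # Poisson negative log-likelihood, up to an additive constant in y
    coef = params.reshape(X_dict["mean"].shape[1], y.shape[1])
    log_mu = X_dict["mean"] @ coef
    return -torch.sum(y * log_mu - torch.exp(log_mu))

def init(X_dict, y, device):
    return torch.zeros(X_dict["mean"].shape[1] * y.shape[1], requires_grad=True, device=device)

fit = multiple_formula_regression_factory(poisson_nll, init, lambda p, X, y: p.detach())

X = torch.ones(1000, 1)                       # intercept-only design
y = torch.poisson(4.0 * torch.ones(1000, 3))  # three genes, rate 4
loader = DataLoader(TensorDataset(X, y), batch_size=128)
print(fit({"mean": loader}, epochs=30))       # each coefficient near log(4) = 1.386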