scdesigner 0.0.5__py3-none-any.whl → 0.0.10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- scdesigner/base/__init__.py +8 -0
- scdesigner/base/copula.py +416 -0
- scdesigner/base/marginal.py +391 -0
- scdesigner/base/simulator.py +59 -0
- scdesigner/copulas/__init__.py +8 -0
- scdesigner/copulas/standard_copula.py +645 -0
- scdesigner/datasets/__init__.py +5 -0
- scdesigner/datasets/pancreas.py +39 -0
- scdesigner/distributions/__init__.py +19 -0
- scdesigner/{minimal → distributions}/bernoulli.py +42 -14
- scdesigner/distributions/gaussian.py +114 -0
- scdesigner/distributions/negbin.py +121 -0
- scdesigner/distributions/negbin_irls.py +72 -0
- scdesigner/distributions/negbin_irls_funs.py +456 -0
- scdesigner/distributions/poisson.py +88 -0
- scdesigner/{minimal → distributions}/zero_inflated_negbin.py +39 -10
- scdesigner/distributions/zero_inflated_poisson.py +103 -0
- scdesigner/simulators/__init__.py +24 -28
- scdesigner/simulators/composite.py +239 -0
- scdesigner/simulators/positive_nonnegative_matrix_factorization.py +477 -0
- scdesigner/simulators/scd3.py +486 -0
- scdesigner/transform/__init__.py +8 -6
- scdesigner/{minimal → transform}/transform.py +1 -1
- scdesigner/{minimal → utils}/kwargs.py +4 -1
- {scdesigner-0.0.5.dist-info → scdesigner-0.0.10.dist-info}/METADATA +1 -1
- scdesigner-0.0.10.dist-info/RECORD +28 -0
- {scdesigner-0.0.5.dist-info → scdesigner-0.0.10.dist-info}/WHEEL +1 -1
- scdesigner/data/__init__.py +0 -16
- scdesigner/data/formula.py +0 -137
- scdesigner/data/group.py +0 -123
- scdesigner/data/sparse.py +0 -39
- scdesigner/diagnose/__init__.py +0 -65
- scdesigner/diagnose/aic_bic.py +0 -119
- scdesigner/diagnose/plot.py +0 -242
- scdesigner/estimators/__init__.py +0 -32
- scdesigner/estimators/bernoulli.py +0 -85
- scdesigner/estimators/gaussian.py +0 -121
- scdesigner/estimators/gaussian_copula_factory.py +0 -367
- scdesigner/estimators/glm_factory.py +0 -75
- scdesigner/estimators/negbin.py +0 -153
- scdesigner/estimators/pnmf.py +0 -160
- scdesigner/estimators/poisson.py +0 -124
- scdesigner/estimators/zero_inflated_negbin.py +0 -195
- scdesigner/estimators/zero_inflated_poisson.py +0 -85
- scdesigner/format/__init__.py +0 -4
- scdesigner/format/format.py +0 -20
- scdesigner/format/print.py +0 -30
- scdesigner/minimal/__init__.py +0 -17
- scdesigner/minimal/composite.py +0 -119
- scdesigner/minimal/copula.py +0 -205
- scdesigner/minimal/formula.py +0 -23
- scdesigner/minimal/gaussian.py +0 -65
- scdesigner/minimal/loader.py +0 -211
- scdesigner/minimal/marginal.py +0 -154
- scdesigner/minimal/negbin.py +0 -73
- scdesigner/minimal/positive_nonnegative_matrix_factorization.py +0 -231
- scdesigner/minimal/scd3.py +0 -96
- scdesigner/minimal/scd3_instances.py +0 -50
- scdesigner/minimal/simulator.py +0 -25
- scdesigner/minimal/standard_copula.py +0 -383
- scdesigner/predictors/__init__.py +0 -15
- scdesigner/predictors/bernoulli.py +0 -9
- scdesigner/predictors/gaussian.py +0 -16
- scdesigner/predictors/negbin.py +0 -17
- scdesigner/predictors/poisson.py +0 -12
- scdesigner/predictors/zero_inflated_negbin.py +0 -18
- scdesigner/predictors/zero_inflated_poisson.py +0 -18
- scdesigner/samplers/__init__.py +0 -23
- scdesigner/samplers/bernoulli.py +0 -27
- scdesigner/samplers/gaussian.py +0 -25
- scdesigner/samplers/glm_factory.py +0 -103
- scdesigner/samplers/negbin.py +0 -25
- scdesigner/samplers/poisson.py +0 -25
- scdesigner/samplers/zero_inflated_negbin.py +0 -40
- scdesigner/samplers/zero_inflated_poisson.py +0 -16
- scdesigner/simulators/composite_regressor.py +0 -72
- scdesigner/simulators/glm_simulator.py +0 -167
- scdesigner/simulators/pnmf_regression.py +0 -61
- scdesigner/transform/amplify.py +0 -14
- scdesigner/transform/mask.py +0 -33
- scdesigner/transform/nullify.py +0 -25
- scdesigner/transform/split.py +0 -23
- scdesigner/transform/substitute.py +0 -14
- scdesigner-0.0.5.dist-info/RECORD +0 -66
scdesigner/diagnose/plot.py
DELETED
|
@@ -1,242 +0,0 @@
|
|
|
1
|
-
import scanpy as sc
|
|
2
|
-
import numpy as np
|
|
3
|
-
import pandas as pd
|
|
4
|
-
import altair as alt
|
|
5
|
-
import matplotlib.pyplot as plt
|
|
6
|
-
|
|
7
|
-
|
|
8
|
-
def adata_df(adata):
    """Reshape an expression matrix into long ("melted") format.

    A wide DataFrame is built from ``adata.X`` with one column per entry of
    ``adata.var_names`` and then melted so that every row holds a single
    (variable, value) pair.

    Args:
        adata: Object exposing ``X`` (a 2-D matrix, dense or scipy-sparse)
            and ``var_names`` (the column labels).

    Returns:
        The melted DataFrame with a fresh ``0..n-1`` index.
    """
    X = adata.X
    # The DataFrame constructor does not accept scipy sparse matrices, so
    # anything exposing ``toarray`` is densified first.
    if hasattr(X, "toarray"):
        X = X.toarray()
    wide = pd.DataFrame(X, columns=adata.var_names)
    melted = wide.melt(id_vars=[], value_vars=adata.var_names)
    return melted.reset_index(drop=True)
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
def merge_samples(adata, sim):
    """Stack the long-format real and simulated data into one DataFrame.

    Both inputs are melted with ``adata_df``; the results are concatenated
    under the keys ``"real"`` and ``"simulated"``, and that key is moved
    into an ordinary column called ``source``.
    """
    frames = {"real": adata_df(adata), "simulated": adata_df(sim)}
    stacked = pd.concat(frames, names=["source"])
    return stacked.reset_index(level="source")
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
def plot_umap(
    adata,
    color=None,
    shape=None,
    facet=None,
    opacity=0.6,
    n_comps=20,
    n_neighbors=15,
    transform=lambda x: np.log1p(x),
    **kwargs
):
    """Draw a UMAP scatter plot of (transformed) expression data.

    Args:
        adata: Annotated data object; it is copied, never modified.
        color: Optional encoding for the point colour.
        shape: Optional encoding for the point shape.
        facet: Optional field to facet the chart by (as columns).
        opacity: Point opacity.
        n_comps: Number of principal components; passed to ``sc.pp.pca`` as
            ``n_comps`` and to ``sc.pp.neighbors`` as ``n_pcs``.
        n_neighbors: Passed to ``sc.pp.neighbors``.
        transform: Callable applied to the dense matrix before PCA.
        **kwargs: Forwarded to ``sc.tl.umap``.

    Returns:
        The Altair chart.
    """
    mapping = {"x": "UMAP1", "y": "UMAP2", "color": color, "shape": shape}
    mapping = {k: v for k, v in mapping.items() if v is not None}

    adata_ = adata.copy()
    adata_.X = check_sparse(adata_.X)

    # Apply the transform exactly once and reuse the result in both branches.
    Z = transform(adata_.X)
    if Z.shape[1] == adata_.X.shape[1]:
        adata_.X = Z
    else:
        # The transform changed the number of columns. Keep the first
        # Z.shape[1] variables as placeholders and relabel them. The slice is
        # copied so that X is not assigned on a view.
        # NOTE(review): this only works when the transform yields fewer
        # columns than the input -- confirm that wider outputs never occur.
        adata_ = adata_[:, : Z.shape[1]].copy()
        adata_.X = Z
        adata_.var_names = [f"transform_{k}" for k in range(Z.shape[1])]

    # umap on the top PCA dimensions
    sc.pp.pca(adata_, n_comps=n_comps)
    sc.pp.neighbors(adata_, n_neighbors=n_neighbors, n_pcs=n_comps)
    sc.tl.umap(adata_, **kwargs)

    # get umap embeddings
    umap_df = pd.DataFrame(adata_.obsm["X_umap"], columns=["UMAP1", "UMAP2"])
    umap_df = pd.concat([umap_df, adata_.obs.reset_index(drop=True)], axis=1)

    # encode and visualize
    chart = alt.Chart(umap_df).mark_point(opacity=opacity).encode(**mapping)
    if facet is not None:
        chart = chart.facet(column=alt.Facet(facet))
    return chart
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
def plot_pca(
    adata,
    color=None,
    shape=None,
    facet=None,
    opacity=0.6,
    plot_dims=(0, 1),
    transform=lambda x: np.log1p(x),
    **kwargs
):
    """Draw a scatter plot of two principal components.

    Args:
        adata: Annotated data object; it is copied, never modified.
        color: Optional encoding for the point colour.
        shape: Optional encoding for the point shape.
        facet: Optional field to facet the chart by (as columns).
        opacity: Point opacity.
        plot_dims: The two component indices to plot. They are always
            labelled ``PCA1`` and ``PCA2`` in the chart.
        transform: Callable applied to the dense matrix before PCA.
        **kwargs: Forwarded to ``sc.pp.pca``.

    Returns:
        The Altair chart.

    Raises:
        ValueError: If ``plot_dims`` does not contain exactly two entries.
    """
    # An immutable default replaces the shared list; any sequence is accepted.
    dims = list(plot_dims)
    if len(dims) != 2:
        raise ValueError("plot_dims must contain exactly two component indices")

    mapping = {"x": "PCA1", "y": "PCA2", "color": color, "shape": shape}
    mapping = {k: v for k, v in mapping.items() if v is not None}

    adata_ = adata.copy()
    adata_.X = check_sparse(adata_.X)
    adata_.X = transform(adata_.X)

    # get PCA scores
    sc.pp.pca(adata_, **kwargs)
    pca_df = pd.DataFrame(adata_.obsm["X_pca"][:, dims], columns=["PCA1", "PCA2"])
    pca_df = pd.concat([pca_df, adata_.obs.reset_index(drop=True)], axis=1)

    # plot
    chart = alt.Chart(pca_df).mark_point(opacity=opacity).encode(**mapping)
    if facet is not None:
        chart = chart.facet(column=alt.Facet(facet))
    return chart
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
def compare_summary(real, simulated, summary_fun):
    """Scatter a per-variable summary of real data against simulated data.

    Args:
        real: Real dataset, passed to ``summary_fun``.
        simulated: Simulated dataset, passed to ``summary_fun``.
        summary_fun: Callable returning one summary value per variable.

    Returns:
        A layered Altair chart: a grey ``y = x`` reference line underneath
        one circle per variable.
    """
    df = pd.DataFrame({"real": summary_fun(real), "simulated": summary_fun(simulated)})

    # Span the reference line over the range of BOTH columns so it still
    # reaches points whose simulated value lies outside the real range.
    lo = df.min().min()
    hi = df.max().max()
    identity = pd.DataFrame({"real": [lo, hi], "simulated": [lo, hi]})

    reference = alt.Chart(identity).mark_line(color="#dedede").encode(
        x="real", y="simulated"
    )
    points = alt.Chart(df).mark_circle().encode(x="real", y="simulated")
    return reference + points
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
def check_sparse(X):
    """Return ``X`` as a dense array.

    ndarray inputs are returned unchanged (same object). Anything else is
    densified, preferring ``toarray()`` so the result is a plain
    ``np.ndarray`` rather than the ``np.matrix`` that ``todense()`` yields.

    Args:
        X: A numpy array, a scipy-sparse matrix, or another array-like.

    Returns:
        A dense array holding the same values.
    """
    if isinstance(X, np.ndarray):
        return X
    if hasattr(X, "toarray"):
        return X.toarray()
    if hasattr(X, "todense"):
        # Fallback for containers that only offer todense().
        return np.asarray(X.todense())
    return np.asarray(X)
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
def compare_means(real, simulated, transform=lambda x: x):
    """Compare per-variable means of real and simulated data.

    Both datasets are densified with ``prepare_dense``; the column means of
    ``transform(X)`` are then handed to ``compare_summary``.
    """
    def column_means(a):
        # np.asarray + flatten gives a flat 1-D vector either way.
        return np.asarray(transform(a.X).mean(axis=0)).flatten()

    dense_real, dense_sim = prepare_dense(real, simulated)
    return compare_summary(dense_real, dense_sim, column_means)
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
def prepare_dense(real, simulated):
    """Return copies of both datasets whose ``X`` went through ``check_sparse``.

    The inputs themselves are left untouched; only the copies are modified.
    """
    def densified(a):
        dup = a.copy()
        dup.X = check_sparse(dup.X)
        return dup

    return densified(real), densified(simulated)
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
def compare_variances(real, simulated, transform=lambda x: x):
    """Compare per-variable variances of real and simulated data.

    Both datasets are densified with ``prepare_dense``; ``np.var`` over
    axis 0 of ``transform(X)`` is then handed to ``compare_summary``.
    """
    def column_variances(a):
        return np.asarray(np.var(transform(a.X), axis=0)).flatten()

    dense_real, dense_sim = prepare_dense(real, simulated)
    return compare_summary(dense_real, dense_sim, column_variances)
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
def compare_standard_deviation(real, simulated, transform=lambda x: x):
    """Compare per-variable standard deviations of real and simulated data.

    Both datasets are densified with ``prepare_dense``; ``np.std`` over
    axis 0 of ``transform(X)`` is then handed to ``compare_summary``.
    """
    def column_sdevs(a):
        return np.asarray(np.std(transform(a.X), axis=0)).flatten()

    dense_real, dense_sim = prepare_dense(real, simulated)
    return compare_summary(dense_real, dense_sim, column_sdevs)
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
def concat_real_sim(real, simulated):
    """Concatenate real and simulated data, tagging each row with its origin.

    Dense copies of both inputs receive an ``obs["source"]`` column
    (``"real"`` or ``"simulated"``) and are joined with an outer join.
    """
    real_, simulated_ = prepare_dense(real, simulated)
    for label, part in (("real", real_), ("simulated", simulated_)):
        part.obs["source"] = label
    # NOTE(review): concatenate() looks like the legacy AnnData API --
    # confirm whether it is deprecated in the anndata version in use.
    return real_.concatenate(simulated_, join="outer", batch_key=None)
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
def compare_umap(real, simulated, transform=lambda x: x, **kwargs):
    """UMAP plot of real and simulated data, faceted by ``source``.

    Extra keyword arguments are forwarded to ``plot_umap``.
    """
    merged = concat_real_sim(real, simulated)
    chart = plot_umap(merged, facet="source", transform=transform, **kwargs)
    return chart
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
def compare_pca(real, simulated, transform=lambda x: x, **kwargs):
    """PCA plot of real and simulated data, faceted by ``source``.

    Extra keyword arguments are forwarded to ``plot_pca``.
    """
    merged = concat_real_sim(real, simulated)
    chart = plot_pca(merged, facet="source", transform=transform, **kwargs)
    return chart
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
def plot_hist(sim_data, real_data, idx):
    """Show side-by-side histograms of one column of real and simulated data.

    Args:
        sim_data: 2-D simulated matrix, indexed as ``sim_data[:, idx]``.
        real_data: 2-D real matrix, indexed as ``real_data[:, idx]``.
        idx: Column to plot.

    The figure is displayed with ``plt.show()``; nothing is returned.
    """
    def column(data):
        col = data[:, idx]
        if hasattr(col, "toarray"):
            col = col.toarray()
        # Flatten so the values form a single 1-D dataset for plt.hist.
        return np.asarray(col).ravel()

    sim = column(sim_data)
    real = column(real_data)

    # Vectorised reductions instead of Python-level min()/max() scans.
    lo = min(sim.min(), real.min())
    hi = max(sim.max(), real.max())
    b = np.linspace(lo, hi, 50)

    plt.hist([real, sim], b, label=["Real", "Simulated"], histtype="bar")
    plt.xlabel("x")
    # NOTE(review): the axis is labelled "Density" but hist() is not called
    # with density=True -- confirm which of the two is intended.
    plt.ylabel("Density")
    plt.legend()
    plt.show()
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
def compare_ecdf(adata, sim, var_names=None, max_plot=10, n_cols=5, **kwargs):
    """Plot empirical CDFs of real versus simulated values, one facet per variable.

    Args:
        adata: Real data, indexable as ``adata[:, var_names]``.
        sim: Object whose ``sample()`` result is indexable the same way.
        var_names: Variables to plot; defaults to the first ``max_plot``
            entries of ``adata.var_names``.
        max_plot: Number of variables used when ``var_names`` is None.
        n_cols: Number of facet columns.
        **kwargs: Forwarded to the chart's ``properties``.

    Returns:
        ``(plot, combined)``: the Altair chart and the long-format DataFrame
        it was drawn from. The chart is also displayed via ``plot.show()``.
    """
    if var_names is None:
        var_names = adata.var_names[:max_plot]

    combined = merge_samples(adata[:, var_names], sim.sample()[:, var_names])
    alt.data_transformers.enable("vegafusion")

    plot = (
        alt.Chart(combined)
        .transform_window(
            ecdf="cume_dist()",
            sort=[{"field": "value"}],
            # The cumulative distribution must be computed per source as well
            # as per variable. Grouping by variable alone pools real and
            # simulated rows, so neither line would be that source's ECDF.
            groupby=["variable", "source"],
        )
        .mark_line(
            interpolate="step-after",
        )
        .encode(
            x="value:Q",
            y="ecdf:Q",
            color="source:N",
            facet=alt.Facet(
                "variable", sort=alt.EncodingSortField("value"), columns=n_cols
            ),
        )
        .properties(**kwargs)
    )
    plot.show()
    return plot, combined
|
|
195
|
-
|
|
196
|
-
|
|
197
|
-
def compare_boxplot(adata, sim, var_names=None, max_plot=20, **kwargs):
    """Draw per-variable boxplots of real and simulated values, faceted by source.

    ``var_names`` defaults to the first ``max_plot`` entries of
    ``adata.var_names``. Extra keyword arguments go to the chart's
    ``properties``. The chart is shown and returned together with the
    long-format DataFrame it was drawn from.
    """
    if var_names is None:
        var_names = adata.var_names[:max_plot]

    combined = merge_samples(adata[:, var_names], sim.sample()[:, var_names])
    alt.data_transformers.enable("vegafusion")

    boxes = alt.Chart(combined).mark_boxplot(extent="min-max")
    x_encoding = alt.X("value:Q").scale(zero=False)
    y_encoding = alt.Y(
        "variable:N",
        sort=alt.EncodingSortField("mid_box_value", order="descending"),
    )
    plot = boxes.encode(x=x_encoding, y=y_encoding, facet="source:N").properties(**kwargs)

    plot.show()
    return plot, combined
|
|
219
|
-
|
|
220
|
-
|
|
221
|
-
def compare_histogram(adata, sim, var_names=None, max_plot=20, n_cols=5, **kwargs):
    """Overlay histograms of real and simulated values, one facet per variable.

    Args:
        adata: Real data, indexable as ``adata[:, var_names]``.
        sim: Object whose ``sample()`` result is indexable the same way.
        var_names: Variables to plot; defaults to the first ``max_plot``
            entries of ``adata.var_names``.
        max_plot: Number of variables used when ``var_names`` is None.
        n_cols: Number of facet columns.
        **kwargs: Forwarded to the chart's ``properties``.

    Returns:
        ``(plot, combined)``: the Altair chart and the long-format DataFrame
        it was drawn from. The chart is also displayed via ``plot.show()``.
    """
    if var_names is None:
        var_names = adata.var_names[:max_plot]

    combined = merge_samples(adata[:, var_names], sim.sample()[:, var_names])
    alt.data_transformers.enable("vegafusion")

    plot = (
        alt.Chart(combined)
        .mark_bar(opacity=0.7)
        .encode(
            x=alt.X("value:Q").bin(maxbins=20),
            y=alt.Y("count()").stack(None),
            color="source:N",
            facet=alt.Facet(
                "variable",
                sort=alt.EncodingSortField("bin_maxbins_20_value"),
                # n_cols was accepted but never used; wire it into the facet.
                columns=n_cols,
            ),
        )
        .properties(**kwargs)
    )
    plot.show()
    return plot, combined
|
|
@@ -1,32 +0,0 @@
|
|
|
1
|
-
from .negbin import negbin_regression, negbin_copula, fast_negbin_copula_factory
|
|
2
|
-
from .gaussian_copula_factory import group_indices, fast_copula_covariance, FastCovarianceStructure, fast_gaussian_copula_array_factory
|
|
3
|
-
from .poisson import poisson_regression, poisson_copula, fast_poisson_copula_factory
|
|
4
|
-
from .bernoulli import bernoulli_regression, bernoulli_copula
|
|
5
|
-
from .gaussian import gaussian_regression, gaussian_copula
|
|
6
|
-
from .zero_inflated_negbin import (
|
|
7
|
-
zero_inflated_negbin_regression,
|
|
8
|
-
zero_inflated_negbin_copula,
|
|
9
|
-
)
|
|
10
|
-
from .zero_inflated_poisson import zero_inflated_poisson_regression
|
|
11
|
-
from .glm_factory import multiple_formula_regression_factory
|
|
12
|
-
|
|
13
|
-
# Public names re-exported by this package (same names, same order).
__all__ = [
    # Bernoulli
    "bernoulli_regression",
    "bernoulli_copula",
    # negative binomial
    "negbin_copula",
    "negbin_regression",
    # Gaussian
    "gaussian_regression",
    "gaussian_copula",
    "group_indices",
    # Poisson
    "poisson_copula",
    "poisson_regression",
    # zero-inflated models
    "zero_inflated_negbin_copula",
    "zero_inflated_negbin_regression",
    "zero_inflated_poisson_regression",
    # factories and "fast" copula variants
    "multiple_formula_regression_factory",
    "fast_copula_covariance",
    "FastCovarianceStructure",
    "fast_gaussian_copula_array_factory",
    "fast_negbin_copula_factory",
    "fast_poisson_copula_factory",
]
|
|
@@ -1,85 +0,0 @@
|
|
|
1
|
-
import pandas as pd
|
|
2
|
-
from . import gaussian_copula_factory as gcf
|
|
3
|
-
from . import glm_factory as factory
|
|
4
|
-
from .. import data
|
|
5
|
-
from .. import format
|
|
6
|
-
from . import poisson as poi
|
|
7
|
-
from anndata import AnnData
|
|
8
|
-
from scipy.stats import bernoulli
|
|
9
|
-
import numpy as np
|
|
10
|
-
import torch
|
|
11
|
-
from typing import Union
|
|
12
|
-
|
|
13
|
-
###############################################################################
|
|
14
|
-
## Regression functions that operate on numpy arrays
|
|
15
|
-
###############################################################################
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
def bernoulli_regression_likelihood(params, X_dict, y):
    """Negative log-likelihood of a logistic (Bernoulli) regression.

    Args:
        params: Flat tensor holding ``n_features * n_outcomes`` coefficients.
        X_dict: Dict whose ``"mean"`` entry is the design matrix.
        y: Response tensor with one column per outcome.

    Returns:
        Scalar tensor: the summed negative log-likelihood.
    """
    # get appropriate parameter shape
    n_features = X_dict["mean"].shape[1]
    n_outcomes = y.shape[1]
    coef_mean = params.reshape(n_features, n_outcomes)

    # Stay in logit space. log(sigmoid(z)) and log(1 - sigmoid(z)) evaluate
    # to -inf once the sigmoid saturates, which makes the sum NaN through
    # 0 * -inf. logsigmoid(z) and logsigmoid(-z) are the stable equivalents.
    logits = X_dict["mean"] @ coef_mean
    log_sigmoid = torch.nn.functional.logsigmoid
    log_likelihood = y * log_sigmoid(logits) + (1 - y) * log_sigmoid(-logits)
    return -torch.sum(log_likelihood)
|
|
28
|
-
|
|
29
|
-
def bernoulli_initializer(X_dict, y, device):
    """Return an all-zero, gradient-tracking starting vector.

    It holds one coefficient per (mean-design column, outcome column) pair
    and is allocated on ``device``.
    """
    n_params = X_dict["mean"].shape[1] * y.shape[1]
    return torch.zeros(n_params, requires_grad=True, device=device)
|
|
33
|
-
|
|
34
|
-
def bernoulli_postprocessor(params, X_dict, y):
    """Reshape the fitted flat vector into a coefficient matrix.

    The result of ``format.to_np(params)`` is reshaped to
    ``(n_features, n_outcomes)`` and returned under the key ``"coef_mean"``.
    """
    n_features = X_dict["mean"].shape[1]
    n_outcomes = y.shape[1]
    flat = format.to_np(params)
    return {"coef_mean": flat.reshape(n_features, n_outcomes)}
|
|
37
|
-
|
|
38
|
-
# Array-level Bernoulli regression, assembled by the generic GLM factory from
# the likelihood, initializer and postprocessor.
bernoulli_regression_array = factory.multiple_formula_regression_factory(
    bernoulli_regression_likelihood,
    bernoulli_initializer,
    bernoulli_postprocessor,
)
|
|
41
|
-
|
|
42
|
-
###############################################################################
|
|
43
|
-
## Regression functions that operate on AnnData objects
|
|
44
|
-
###############################################################################
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
def bernoulli_regression(
    adata: AnnData, formula: Union[str, dict], chunk_size: int = int(1e4), batch_size=512, **kwargs
) -> dict:
    """Fit a Bernoulli regression to an AnnData object.

    The formula is standardised (only the ``mean`` key is allowed), data
    loaders are built from it, and the array-level regression is run with
    any extra keyword arguments. The fitted parameters are then passed to
    ``poi.format_poisson_parameters`` along with the variable names and the
    mean-design column names.
    """
    mean_formula = data.standardize_formula(formula, allowed_keys=['mean'])
    loaders = data.multiple_formula_loader(
        adata, mean_formula, chunk_size=chunk_size, batch_size=batch_size
    )
    fitted = bernoulli_regression_array(loaders, **kwargs)

    outcome_names = list(adata.var_names)
    covariate_names = list(loaders["mean"].dataset.x_names)
    return poi.format_poisson_parameters(fitted, outcome_names, covariate_names)
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
###############################################################################
|
|
61
|
-
## Copula versions for bernoulli regression
|
|
62
|
-
###############################################################################
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
def bernoulli_uniformizer(parameters, X_dict, y, epsilon=1e-3, random_seed=42):
    """Map Bernoulli observations to randomised uniform values.

    For each entry, a value is drawn between the Bernoulli CDF at ``y - 1``
    (taken as 0 where ``y <= 0``) and the CDF at ``y``, then clipped to
    ``[epsilon, 1 - epsilon]``.

    Args:
        parameters: Dict with the coefficient matrix under ``"coef_mean"``.
        X_dict: Dict whose ``"mean"`` entry is the design matrix (numpy).
        y: Observed array.
        epsilon: Clipping margin that keeps results away from 0 and 1.
        random_seed: Seed for the jitter.

    Returns:
        Array shaped like ``y`` with values in ``[epsilon, 1 - epsilon]``.
    """
    # A private RandomState yields the same draws as seeding the global
    # generator with this seed, without resetting np.random's global state
    # for the caller.
    rng = np.random.RandomState(random_seed)

    theta = torch.sigmoid(torch.from_numpy(X_dict["mean"] @ parameters["coef_mean"])).numpy()
    dist = bernoulli(theta)  # frozen once, used for both CDF evaluations
    u1 = dist.cdf(y)
    u2 = np.where(y > 0, dist.cdf(y - 1), 0)
    v = rng.uniform(size=y.shape)
    return np.clip(v * u1 + (1 - v) * u2, epsilon, 1 - epsilon)
|
|
72
|
-
|
|
73
|
-
def format_bernoulli_parameters_with_loaders(parameters: dict, var_names: list, dls: dict) -> dict:
    """Label the Bernoulli coefficient matrix with row and column names.

    ``parameters["coef_mean"]`` is replaced in place by a DataFrame whose
    columns are ``var_names`` and whose index is the ``x_names`` of the
    ``"mean"`` loader's dataset. The same dict is returned.
    """
    row_labels = dls["mean"].dataset.x_names
    labelled = pd.DataFrame(parameters["coef_mean"], columns=var_names, index=row_labels)
    parameters["coef_mean"] = labelled
    return parameters
|
|
80
|
-
|
|
81
|
-
# Gaussian-copula estimator with Bernoulli margins: the array-level copula is
# built from the Bernoulli regression and uniformizer, then wrapped by the
# copula factory together with the parameter formatter.
bernoulli_copula = gcf.gaussian_copula_factory(
    gcf.gaussian_copula_array_factory(
        bernoulli_regression_array,
        bernoulli_uniformizer,
    ),
    format_bernoulli_parameters_with_loaders,
    param_name=['mean'],
)
|
|
@@ -1,121 +0,0 @@
|
|
|
1
|
-
from . import gaussian_copula_factory as gcf
|
|
2
|
-
from . import glm_factory as factory
|
|
3
|
-
from .. import format
|
|
4
|
-
from .. import data
|
|
5
|
-
from anndata import AnnData
|
|
6
|
-
from scipy.stats import norm
|
|
7
|
-
import numpy as np
|
|
8
|
-
import pandas as pd
|
|
9
|
-
import torch
|
|
10
|
-
from typing import Union
|
|
11
|
-
|
|
12
|
-
###############################################################################
|
|
13
|
-
## Regression functions that operate on numpy arrays
|
|
14
|
-
###############################################################################
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
# Gaussian regression likelihood: regression onto mean and sdev
|
|
19
|
-
def gaussian_regression_likelihood(params, X_dict, y):
    """Negative log-likelihood of a Gaussian regression on mean and log-sdev.

    Args:
        params: Flat tensor; the first ``n_mean_features * n_outcomes``
            entries are the mean coefficients and the remainder are the
            (log-scale) standard-deviation coefficients.
        X_dict: Dict with design matrices under ``"mean"`` and ``"sdev"``.
        y: Response tensor with one column per outcome.

    Returns:
        Scalar tensor: the summed NEGATIVE log-likelihood. The sign matches
        ``bernoulli_regression_likelihood``, which is passed to the same
        regression factory.
    """
    n_mean_features = X_dict["mean"].shape[1]
    n_sdev_features = X_dict["sdev"].shape[1]
    n_outcomes = y.shape[1]

    split = n_mean_features * n_outcomes
    coef_mean = params[:split].reshape(n_mean_features, n_outcomes)
    coef_sdev = params[split:].reshape(n_sdev_features, n_outcomes)

    mu = X_dict["mean"] @ coef_mean
    # sigma = exp(log_sigma); staying in log space avoids forming sigma ** 2.
    log_sigma = X_dict["sdev"] @ coef_sdev

    # -log N(y | mu, sigma) = 0.5*log(2*pi) + log(sigma)
    #                         + 0.5 * (y - mu)^2 / sigma^2
    half_log_2pi = 0.5 * float(np.log(2 * np.pi))
    negative_log_likelihood = (
        half_log_2pi + log_sigma + 0.5 * (y - mu) ** 2 * torch.exp(-2 * log_sigma)
    )
    return torch.sum(negative_log_likelihood)
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
def gaussian_initializer(x_dict, y, device):
    """Return an all-zero, gradient-tracking starting vector.

    Its length is ``(n_mean_features + n_sdev_features) * n_outcomes`` and
    it is allocated on ``device``.
    """
    n_outcomes = y.shape[1]
    n_mean_params = x_dict["mean"].shape[1] * n_outcomes
    n_sdev_params = x_dict["sdev"].shape[1] * n_outcomes
    return torch.zeros(n_mean_params + n_sdev_params, requires_grad=True, device=device)
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
def gaussian_postprocessor(params, x_dict, y):
    """Split the fitted flat vector into mean and sdev coefficient matrices.

    The first ``n_mean_features * n_outcomes`` entries become ``coef_mean``;
    the rest become ``coef_sdev``. Both pass through ``format.to_np``
    before being reshaped.
    """
    n_outcomes = y.shape[1]
    n_mean_features = x_dict["mean"].shape[1]
    n_sdev_features = x_dict["sdev"].shape[1]
    split = n_mean_features * n_outcomes

    mean_block = format.to_np(params[:split]).reshape(n_mean_features, n_outcomes)
    sdev_block = format.to_np(params[split:]).reshape(n_sdev_features, n_outcomes)
    return {"coef_mean": mean_block, "coef_sdev": sdev_block}
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
# Array-level Gaussian regression, assembled by the generic GLM factory from
# the likelihood, initializer and postprocessor.
gaussian_regression_array = factory.multiple_formula_regression_factory(
    gaussian_regression_likelihood,
    gaussian_initializer,
    gaussian_postprocessor,
)
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
###############################################################################
|
|
62
|
-
## Regression functions that operate on AnnData objects
|
|
63
|
-
###############################################################################
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
def format_gaussian_parameters(
    parameters: dict, var_names: list, mean_coef_index: list, sdev_coef_index: list
) -> dict:
    """Label both Gaussian coefficient matrices with row and column names.

    ``coef_mean`` and ``coef_sdev`` are replaced in place by DataFrames
    whose columns are ``var_names`` and whose indices are the respective
    coefficient-name lists. The same dict is returned.
    """
    row_labels = {"coef_mean": mean_coef_index, "coef_sdev": sdev_coef_index}
    for key, index in row_labels.items():
        parameters[key] = pd.DataFrame(parameters[key], columns=var_names, index=index)
    return parameters
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
def format_gaussian_parameters_with_loaders(
    parameters: dict, var_names: list, dls: dict
) -> dict:
    """Format Gaussian parameters, taking coefficient names from the loaders.

    The ``x_names`` of the ``"mean"`` and ``"sdev"`` loader datasets are
    read and passed on to ``format_gaussian_parameters``.
    """
    return format_gaussian_parameters(
        parameters,
        var_names,
        dls["mean"].dataset.x_names,
        dls["sdev"].dataset.x_names,
    )
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
def gaussian_regression(
    adata: AnnData, formula: Union[str, dict], chunk_size: int = int(1e4),
    batch_size=512, **kwargs
) -> dict:
    """Fit a Gaussian regression (mean and sdev) to an AnnData object.

    The formula is standardised (keys ``mean`` and ``sdev`` are allowed),
    data loaders are built from it, and the array-level regression is run
    with any extra keyword arguments. The fitted matrices are labelled by
    ``format_gaussian_parameters``.
    """
    std_formula = data.standardize_formula(formula, allowed_keys=['mean', 'sdev'])
    loaders = data.multiple_formula_loader(
        adata, std_formula, chunk_size=chunk_size, batch_size=batch_size
    )
    fitted = gaussian_regression_array(loaders, **kwargs)

    outcome_names = list(adata.var_names)
    mean_names = loaders["mean"].dataset.x_names
    sdev_names = loaders["sdev"].dataset.x_names
    return format_gaussian_parameters(fitted, outcome_names, mean_names, sdev_names)
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
###############################################################################
|
|
104
|
-
## Copula versions for gaussian regression
|
|
105
|
-
###############################################################################
|
|
106
|
-
|
|
107
|
-
def gaussian_uniformizer(parameters, X_dict, y, epsilon=1e-3):
    """Map observations to uniform values through the fitted Gaussian CDF.

    The mean is ``X_mean @ coef_mean`` and the scale is
    ``exp(X_sdev @ coef_sdev)``. The CDF values are clipped to
    ``[epsilon, 1 - epsilon]``.
    """
    loc = X_dict["mean"] @ parameters["coef_mean"]
    log_scale = X_dict["sdev"] @ parameters["coef_sdev"]
    quantiles = norm.cdf(y, loc=loc, scale=np.exp(log_scale))
    return np.clip(quantiles, epsilon, 1 - epsilon)
|
|
113
|
-
|
|
114
|
-
# Array-level Gaussian copula built from the Gaussian regression and its
# uniformizer.
gaussian_copula_array = gcf.gaussian_copula_array_factory(
    gaussian_regression_array,
    gaussian_uniformizer,
)

# The array-level copula wrapped by the copula factory, together with the
# loader-based parameter formatter.
gaussian_copula = gcf.gaussian_copula_factory(
    gaussian_copula_array,
    format_gaussian_parameters_with_loaders,
    param_name=['mean', 'sdev'],
)
|