scdesigner-0.0.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.



Files changed (66)
  1. scdesigner/__init__.py +0 -0
  2. scdesigner/data/__init__.py +16 -0
  3. scdesigner/data/formula.py +137 -0
  4. scdesigner/data/group.py +123 -0
  5. scdesigner/data/sparse.py +39 -0
  6. scdesigner/diagnose/__init__.py +65 -0
  7. scdesigner/diagnose/aic_bic.py +119 -0
  8. scdesigner/diagnose/plot.py +242 -0
  9. scdesigner/estimators/__init__.py +27 -0
  10. scdesigner/estimators/bernoulli.py +85 -0
  11. scdesigner/estimators/gaussian.py +121 -0
  12. scdesigner/estimators/gaussian_copula_factory.py +152 -0
  13. scdesigner/estimators/glm_factory.py +75 -0
  14. scdesigner/estimators/negbin.py +129 -0
  15. scdesigner/estimators/pnmf.py +160 -0
  16. scdesigner/estimators/poisson.py +100 -0
  17. scdesigner/estimators/zero_inflated_negbin.py +195 -0
  18. scdesigner/estimators/zero_inflated_poisson.py +85 -0
  19. scdesigner/format/__init__.py +4 -0
  20. scdesigner/format/format.py +20 -0
  21. scdesigner/format/print.py +30 -0
  22. scdesigner/minimal/__init__.py +17 -0
  23. scdesigner/minimal/bernoulli.py +61 -0
  24. scdesigner/minimal/composite.py +119 -0
  25. scdesigner/minimal/copula.py +33 -0
  26. scdesigner/minimal/formula.py +23 -0
  27. scdesigner/minimal/gaussian.py +65 -0
  28. scdesigner/minimal/kwargs.py +24 -0
  29. scdesigner/minimal/loader.py +166 -0
  30. scdesigner/minimal/marginal.py +140 -0
  31. scdesigner/minimal/negbin.py +73 -0
  32. scdesigner/minimal/positive_nonnegative_matrix_factorization.py +231 -0
  33. scdesigner/minimal/scd3.py +95 -0
  34. scdesigner/minimal/scd3_instances.py +50 -0
  35. scdesigner/minimal/simulator.py +25 -0
  36. scdesigner/minimal/standard_covariance.py +124 -0
  37. scdesigner/minimal/transform.py +145 -0
  38. scdesigner/minimal/zero_inflated_negbin.py +86 -0
  39. scdesigner/predictors/__init__.py +15 -0
  40. scdesigner/predictors/bernoulli.py +9 -0
  41. scdesigner/predictors/gaussian.py +16 -0
  42. scdesigner/predictors/negbin.py +17 -0
  43. scdesigner/predictors/poisson.py +12 -0
  44. scdesigner/predictors/zero_inflated_negbin.py +18 -0
  45. scdesigner/predictors/zero_inflated_poisson.py +18 -0
  46. scdesigner/samplers/__init__.py +23 -0
  47. scdesigner/samplers/bernoulli.py +27 -0
  48. scdesigner/samplers/gaussian.py +25 -0
  49. scdesigner/samplers/glm_factory.py +41 -0
  50. scdesigner/samplers/negbin.py +25 -0
  51. scdesigner/samplers/poisson.py +25 -0
  52. scdesigner/samplers/zero_inflated_negbin.py +40 -0
  53. scdesigner/samplers/zero_inflated_poisson.py +16 -0
  54. scdesigner/simulators/__init__.py +31 -0
  55. scdesigner/simulators/composite_regressor.py +72 -0
  56. scdesigner/simulators/glm_simulator.py +167 -0
  57. scdesigner/simulators/pnmf_regression.py +61 -0
  58. scdesigner/transform/__init__.py +7 -0
  59. scdesigner/transform/amplify.py +14 -0
  60. scdesigner/transform/mask.py +33 -0
  61. scdesigner/transform/nullify.py +25 -0
  62. scdesigner/transform/split.py +23 -0
  63. scdesigner/transform/substitute.py +14 -0
  64. scdesigner-0.0.1.dist-info/METADATA +23 -0
  65. scdesigner-0.0.1.dist-info/RECORD +66 -0
  66. scdesigner-0.0.1.dist-info/WHEEL +4 -0
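The listing above sketches the package architecture: scdesigner/data builds formula-based loaders from an AnnData object, scdesigner/estimators fits marginal regressions and Gaussian copulas, and scdesigner/diagnose compares real against simulated data. As a rough orientation before the hunks below, here is a minimal usage sketch based only on the function signatures visible in this diff; the toy data, the "~ condition" formula, and the epoch counts are illustrative assumptions, and the snippet has not been verified against the wheel.

import numpy as np
import pandas as pd
from anndata import AnnData
from scdesigner.estimators import bernoulli_regression, bernoulli_copula

# toy binary expression matrix with one categorical covariate (illustrative only)
rng = np.random.default_rng(0)
X = rng.binomial(1, 0.3, size=(200, 5)).astype(float)
obs = pd.DataFrame({"condition": rng.choice(["a", "b"], size=200)})
adata = AnnData(X=X, obs=obs)

# per-gene Bernoulli regression; expected to return {"coef_mean": DataFrame}
params = bernoulli_regression(adata, "~ condition", epochs=5)

# same marginals plus a Gaussian copula covariance across genes
copula_params = bernoulli_copula(adata, "~ condition", epochs=5)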
scdesigner/diagnose/plot.py
@@ -0,0 +1,242 @@
+ import scanpy as sc
+ import numpy as np
+ import pandas as pd
+ import altair as alt
+ import matplotlib.pyplot as plt
+
+
+ def adata_df(adata):
+     return (
+         pd.DataFrame(adata.X, columns=adata.var_names)
+         .melt(id_vars=[], value_vars=adata.var_names)
+         .reset_index(drop=True)
+     )
+
+
+ def merge_samples(adata, sim):
+     source = adata_df(adata)
+     simulated = adata_df(sim)
+     return pd.concat(
+         {"real": source, "simulated": simulated}, names=["source"]
+     ).reset_index(level="source")
+
+
+ def plot_umap(
+     adata,
+     color=None,
+     shape=None,
+     facet=None,
+     opacity=0.6,
+     n_comps=20,
+     n_neighbors=15,
+     transform=lambda x: np.log1p(x),
+     **kwargs
+ ):
+     mapping = {"x": "UMAP1", "y": "UMAP2", "color": color, "shape": shape}
+     mapping = {k: v for k, v in mapping.items() if v is not None}
+
+     adata_ = adata.copy()
+     adata_.X = check_sparse(adata_.X)
+     Z = transform(adata_.X)
+     if Z.shape[1] == adata_.X.shape[1]:
+         adata_.X = transform(adata_.X)
+     else:
+         adata_ = adata_[:, : Z.shape[1]]
+         adata_.X = Z
+         adata_.var_names = [f"transform_{k}" for k in range(Z.shape[1])]
+
+     # umap on the top PCA dimensions
+     sc.pp.pca(adata_, n_comps=n_comps)
+     sc.pp.neighbors(adata_, n_neighbors=n_neighbors, n_pcs=n_comps)
+     sc.tl.umap(adata_, **kwargs)
+
+     # get umap embeddings
+     umap_df = pd.DataFrame(adata_.obsm["X_umap"], columns=["UMAP1", "UMAP2"])
+     umap_df = pd.concat([umap_df, adata_.obs.reset_index(drop=True)], axis=1)
+
+     # encode and visualize
+     chart = alt.Chart(umap_df).mark_point(opacity=opacity).encode(**mapping)
+     if facet is not None:
+         chart = chart.facet(column=alt.Facet(facet))
+     return chart
+
+
+ def plot_pca(
+     adata,
+     color=None,
+     shape=None,
+     facet=None,
+     opacity=0.6,
+     plot_dims=[0, 1],
+     transform=lambda x: np.log1p(x),
+     **kwargs
+ ):
+     mapping = {"x": "PCA1", "y": "PCA2", "color": color, "shape": shape}
+     mapping = {k: v for k, v in mapping.items() if v is not None}
+
+     adata_ = adata.copy()
+     adata_.X = check_sparse(adata_.X)
+     adata_.X = transform(adata_.X)
+
+     # get PCA scores
+     sc.pp.pca(adata_, **kwargs)
+     pca_df = pd.DataFrame(adata_.obsm["X_pca"][:, plot_dims], columns=["PCA1", "PCA2"])
+     pca_df = pd.concat([pca_df, adata_.obs.reset_index(drop=True)], axis=1)
+
+     # plot
+     chart = alt.Chart(pca_df).mark_point(opacity=opacity).encode(**mapping)
+     if facet is not None:
+         chart = chart.facet(column=alt.Facet(facet))
+     return chart
+
+
+ def compare_summary(real, simulated, summary_fun):
+     df = pd.DataFrame({"real": summary_fun(real), "simulated": summary_fun(simulated)})
+
+     identity = pd.DataFrame(
+         {
+             "real": [df["real"].min(), df["real"].max()],
+             "simulated": [df["real"].min(), df["real"].max()],
+         }
+     )
+     return alt.Chart(identity).mark_line(color="#dedede").encode(
+         x="real", y="simulated"
+     ) + alt.Chart(df).mark_circle().encode(x="real", y="simulated")
+
+
+ def check_sparse(X):
+     if not isinstance(X, np.ndarray):
+         X = X.todense()
+     return X
+
+
+ def compare_means(real, simulated, transform=lambda x: x):
+     real_, simulated_ = prepare_dense(real, simulated)
+     summary = lambda a: np.asarray(transform(a.X).mean(axis=0)).flatten()
+     return compare_summary(real_, simulated_, summary)
+
+
+ def prepare_dense(real, simulated):
+     real_ = real.copy()
+     simulated_ = simulated.copy()
+     real_.X = check_sparse(real_.X)
+     simulated_.X = check_sparse(simulated_.X)
+     return real_, simulated_
+
+
+ def compare_variances(real, simulated, transform=lambda x: x):
+     real_, simulated_ = prepare_dense(real, simulated)
+     summary = lambda a: np.asarray(np.var(transform(a.X), axis=0)).flatten()
+     return compare_summary(real_, simulated_, summary)
+
+
+ def compare_standard_deviation(real, simulated, transform=lambda x: x):
+     real_, simulated_ = prepare_dense(real, simulated)
+     summary = lambda a: np.asarray(np.std(transform(a.X), axis=0)).flatten()
+     return compare_summary(real_, simulated_, summary)
+
+
+ def concat_real_sim(real, simulated):
+     real_, simulated_ = prepare_dense(real, simulated)
+     real_.obs["source"] = "real"
+     simulated_.obs["source"] = "simulated"
+     return real_.concatenate(simulated_, join="outer", batch_key=None)
+
+
+ def compare_umap(real, simulated, transform=lambda x: x, **kwargs):
+     adata = concat_real_sim(real, simulated)
+     return plot_umap(adata, facet="source", transform=transform, **kwargs)
+
+
+ def compare_pca(real, simulated, transform=lambda x: x, **kwargs):
+     adata = concat_real_sim(real, simulated)
+     return plot_pca(adata, facet="source", transform=transform, **kwargs)
+
+
+ def plot_hist(sim_data, real_data, idx):
+     sim = sim_data[:, idx]
+     real = real_data[:, idx]
+     b = np.linspace(min(min(sim), min(real)), max(max(sim), max(real)), 50)
+
+     plt.hist([real, sim], b, label=["Real", "Simulated"], histtype="bar")
+     plt.xlabel("x")
+     plt.ylabel("Density")
+     plt.legend()
+     plt.show()
+
+
+ def compare_ecdf(adata, sim, var_names=None, max_plot=10, n_cols=5, **kwargs):
+     if var_names is None:
+         var_names = adata.var_names[:max_plot]
+
+     combined = merge_samples(adata[:, var_names], sim.sample()[:, var_names])
+     alt.data_transformers.enable("vegafusion")
+
+     plot = (
+         alt.Chart(combined)
+         .transform_window(
+             ecdf="cume_dist()", sort=[{"field": "value"}], groupby=["variable"]
+         )
+         .mark_line(
+             interpolate="step-after",
+         )
+         .encode(
+             x="value:Q",
+             y="ecdf:Q",
+             color="source:N",
+             facet=alt.Facet(
+                 "variable", sort=alt.EncodingSortField("value"), columns=n_cols
+             ),
+         )
+         .properties(**kwargs)
+     )
+     plot.show()
+     return plot, combined
+
+
+ def compare_boxplot(adata, sim, var_names=None, max_plot=20, **kwargs):
+     if var_names is None:
+         var_names = adata.var_names[:max_plot]
+
+     combined = merge_samples(adata[:, var_names], sim.sample()[:, var_names])
+     alt.data_transformers.enable("vegafusion")
+
+     plot = (
+         alt.Chart(combined)
+         .mark_boxplot(extent="min-max")
+         .encode(
+             x=alt.X("value:Q").scale(zero=False),
+             y=alt.Y(
+                 "variable:N",
+                 sort=alt.EncodingSortField("mid_box_value", order="descending"),
+             ),
+             facet="source:N",
+         )
+         .properties(**kwargs)
+     )
+     plot.show()
+     return plot, combined
+
+
+ def compare_histogram(adata, sim, var_names=None, max_plot=20, n_cols=5, **kwargs):
+     if var_names is None:
+         var_names = adata.var_names[:max_plot]
+
+     combined = merge_samples(adata[:, var_names], sim.sample()[:, var_names])
+     alt.data_transformers.enable("vegafusion")
+
+     plot = (
+         alt.Chart(combined)
+         .mark_bar(opacity=0.7)
+         .encode(
+             x=alt.X("value:Q").bin(maxbins=20),
+             y=alt.Y("count()").stack(None),
+             color="source:N",
+             facet=alt.Facet(
+                 "variable", sort=alt.EncodingSortField("bin_maxbins_20_value")
+             ),
+         )
+         .properties(**kwargs)
+     )
+     plot.show()
+     return plot, combined
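The hunk above is the diagnostic plotting module: per-gene summary scatterplots against an identity line, plus UMAP, PCA, ECDF, boxplot, and histogram comparisons of real versus simulated AnnData objects. A minimal sketch of the summary comparisons, assuming the wheel is installed; the Poisson toy data are illustrative and not from the package:

import numpy as np
from anndata import AnnData
from scdesigner.diagnose.plot import compare_means, compare_variances

rng = np.random.default_rng(0)
real = AnnData(X=rng.poisson(5.0, size=(100, 20)).astype(float))
simulated = AnnData(X=rng.poisson(5.0, size=(100, 20)).astype(float))

# each point is one gene: real mean (x) vs. simulated mean (y), over a grey identity line
mean_chart = compare_means(real, simulated, transform=np.log1p)
var_chart = compare_variances(real, simulated, transform=np.log1p)
mean_chart.save("compare_means.html")  # Altair charts; save or display in a notebook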
scdesigner/estimators/__init__.py
@@ -0,0 +1,27 @@
+ from .negbin import negbin_regression, negbin_copula
+ from .gaussian_copula_factory import group_indices
+ from .poisson import poisson_regression, poisson_copula
+ from .bernoulli import bernoulli_regression, bernoulli_copula
+ from .gaussian import gaussian_regression, gaussian_copula
+ from .zero_inflated_negbin import (
+     zero_inflated_negbin_regression,
+     zero_inflated_negbin_copula,
+ )
+ from .zero_inflated_poisson import zero_inflated_poisson_regression
+ from .glm_factory import multiple_formula_regression_factory
+
+ __all__ = [
+     "bernoulli_regression",
+     "bernoulli_copula",
+     "negbin_copula",
+     "negbin_regression",
+     "gaussian_regression",
+     "gaussian_copula",
+     "group_indices",
+     "poisson_copula",
+     "poisson_regression",
+     "zero_inflated_negbin_copula",
+     "zero_inflated_negbin_regression",
+     "zero_inflated_poisson_regression",
+     "multiple_formula_regression_factory",
+ ]
scdesigner/estimators/bernoulli.py
@@ -0,0 +1,85 @@
+ import pandas as pd
+ from . import gaussian_copula_factory as gcf
+ from . import glm_factory as factory
+ from .. import data
+ from .. import format
+ from . import poisson as poi
+ from anndata import AnnData
+ from scipy.stats import bernoulli
+ import numpy as np
+ import torch
+ from typing import Union
+
+ ###############################################################################
+ ## Regression functions that operate on numpy arrays
+ ###############################################################################
+
+
+ def bernoulli_regression_likelihood(params, X_dict, y):
+     # get appropriate parameter shape
+     n_features = X_dict["mean"].shape[1]
+     n_outcomes = y.shape[1]
+
+     # compute the negative log likelihood
+     coef_mean = params.reshape(n_features, n_outcomes)
+     theta = torch.sigmoid(X_dict["mean"] @ coef_mean)
+     log_likelihood = y * torch.log(theta) + (1 - y) * torch.log(1 - theta)
+     return -torch.sum(log_likelihood)
+
+ def bernoulli_initializer(X_dict, y, device):
+     n_features = X_dict["mean"].shape[1]
+     n_outcomes = y.shape[1]
+     return torch.zeros(n_features * n_outcomes, requires_grad=True, device=device)
+
+ def bernoulli_postprocessor(params, X_dict, y):
+     coef_mean = format.to_np(params).reshape(X_dict["mean"].shape[1], y.shape[1])
+     return {"coef_mean": coef_mean}
+
+ bernoulli_regression_array = factory.multiple_formula_regression_factory(
+     bernoulli_regression_likelihood, bernoulli_initializer, bernoulli_postprocessor
+ )
+
+ ###############################################################################
+ ## Regression functions that operate on AnnData objects
+ ###############################################################################
+
+
+ def bernoulli_regression(
+     adata: AnnData, formula: Union[str, dict], chunk_size: int = int(1e4), batch_size=512, **kwargs
+ ) -> dict:
+     formula = data.standardize_formula(formula, allowed_keys=['mean'])
+     loaders = data.multiple_formula_loader(
+         adata, formula, chunk_size=chunk_size, batch_size=batch_size
+     )
+     parameters = bernoulli_regression_array(loaders, **kwargs)
+     return poi.format_poisson_parameters(
+         parameters, list(adata.var_names), list(loaders["mean"].dataset.x_names)
+     )
+
+
+ ###############################################################################
+ ## Copula versions for bernoulli regression
+ ###############################################################################
+
+
+ def bernoulli_uniformizer(parameters, X_dict, y, epsilon=1e-3, random_seed=42):
+     np.random.seed(random_seed)
+     theta = torch.sigmoid(torch.from_numpy(X_dict["mean"] @ parameters["coef_mean"])).numpy()
+     u1 = bernoulli(theta).cdf(y)
+     u2 = np.where(y > 0, bernoulli(theta).cdf(y - 1), 0)
+     v = np.random.uniform(size=y.shape)
+     return np.clip(v * u1 + (1 - v) * u2, epsilon, 1 - epsilon)
+
+ def format_bernoulli_parameters_with_loaders(parameters: dict, var_names: list, dls: dict) -> dict:
+     coef_mean_index = dls["mean"].dataset.x_names
+
+     parameters["coef_mean"] = pd.DataFrame(
+         parameters["coef_mean"], columns=var_names, index=coef_mean_index
+     )
+     return parameters
+
+ bernoulli_copula = gcf.gaussian_copula_factory(
+     gcf.gaussian_copula_array_factory(bernoulli_regression_array, bernoulli_uniformizer),
+     format_bernoulli_parameters_with_loaders,
+     param_name=['mean']
+ )
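bernoulli_uniformizer above implements the randomized distributional transform for a discrete margin: for each observation it draws a value uniformly between F(y - 1) and F(y) under the fitted Bernoulli model, which is exactly Uniform(0, 1) when the model is correct, and this is what the copula step later maps to normal scores. A standalone check of that property, independent of the package (the toy probability and sample size are arbitrary):

import numpy as np
from scipy.stats import bernoulli, kstest

rng = np.random.default_rng(0)
p = 0.3
y = rng.binomial(1, p, size=20_000)

# randomized distributional transform: interpolate between F(y - 1) and F(y)
u_hi = bernoulli(p).cdf(y)
u_lo = np.where(y > 0, bernoulli(p).cdf(y - 1), 0.0)
v = rng.uniform(size=y.shape)
u = v * u_hi + (1 - v) * u_lo

# should be indistinguishable from Uniform(0, 1)
print(kstest(u, "uniform"))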
scdesigner/estimators/gaussian.py
@@ -0,0 +1,121 @@
+ from . import gaussian_copula_factory as gcf
+ from . import glm_factory as factory
+ from .. import format
+ from .. import data
+ from anndata import AnnData
+ from scipy.stats import norm
+ import numpy as np
+ import pandas as pd
+ import torch
+ from typing import Union
+
+ ###############################################################################
+ ## Regression functions that operate on numpy arrays
+ ###############################################################################
+
+
+
+ # Gaussian regression likelihood: regression onto mean and sdev
+ def gaussian_regression_likelihood(params, X_dict, y):
+     n_mean_features = X_dict["mean"].shape[1]
+     n_sdev_features = X_dict["sdev"].shape[1]
+     n_outcomes = y.shape[1]
+
+     coef_mean = params[: n_mean_features * n_outcomes].reshape(n_mean_features, n_outcomes)
+     coef_sdev = params[n_mean_features * n_outcomes :].reshape(n_sdev_features, n_outcomes)
+     mu = X_dict["mean"] @ coef_mean
+     sigma = torch.exp(X_dict["sdev"] @ coef_sdev)
+
+     # Negative log likelihood for Gaussian (negated so the optimizer minimizes it)
+     log_likelihood = -0.5 * (torch.log(2 * torch.pi * sigma ** 2) + ((y - mu) ** 2) / (sigma ** 2))
+     return -torch.sum(log_likelihood)
+
+
+
+ def gaussian_initializer(x_dict, y, device):
+     n_mean_features = x_dict["mean"].shape[1]
+     n_outcomes = y.shape[1]
+     n_sdev_features = x_dict["sdev"].shape[1]
+     return torch.zeros(
+         n_mean_features * n_outcomes + n_sdev_features * n_outcomes,
+         requires_grad=True, device=device
+     )
+
+
+
+ def gaussian_postprocessor(params, x_dict, y):
+     n_mean_features = x_dict["mean"].shape[1]
+     n_outcomes = y.shape[1]
+     n_sdev_features = x_dict["sdev"].shape[1]
+     coef_mean = format.to_np(params[:n_mean_features * n_outcomes]).reshape(n_mean_features, n_outcomes)
+     coef_sdev = format.to_np(params[n_mean_features * n_outcomes:]).reshape(n_sdev_features, n_outcomes)
+     return {"coef_mean": coef_mean, "coef_sdev": coef_sdev}
+
+
+
+ gaussian_regression_array = factory.multiple_formula_regression_factory(
+     gaussian_regression_likelihood, gaussian_initializer, gaussian_postprocessor
+ )
+
+
+ ###############################################################################
+ ## Regression functions that operate on AnnData objects
+ ###############################################################################
+
+
+ def format_gaussian_parameters(
+     parameters: dict, var_names: list, mean_coef_index: list, sdev_coef_index: list
+ ) -> dict:
+     parameters["coef_mean"] = pd.DataFrame(
+         parameters["coef_mean"], columns=var_names, index=mean_coef_index
+     )
+     parameters["coef_sdev"] = pd.DataFrame(
+         parameters["coef_sdev"], columns=var_names, index=sdev_coef_index
+     )
+     return parameters
+
+
+ def format_gaussian_parameters_with_loaders(
+     parameters: dict, var_names: list, dls: dict
+ ) -> dict:
+     mean_coef_index = dls["mean"].dataset.x_names
+     sdev_coef_index = dls["sdev"].dataset.x_names
+     return format_gaussian_parameters(
+         parameters, var_names, mean_coef_index, sdev_coef_index
+     )
+
+
+ def gaussian_regression(
+     adata: AnnData, formula: Union[str, dict], chunk_size: int = int(1e4),
+     batch_size=512, **kwargs
+ ) -> dict:
+     formula = data.standardize_formula(formula, allowed_keys=['mean', 'sdev'])
+     loaders = data.multiple_formula_loader(
+         adata, formula, chunk_size=chunk_size, batch_size=batch_size
+     )
+     parameters = gaussian_regression_array(loaders, **kwargs)
+     return format_gaussian_parameters(
+         parameters, list(adata.var_names), loaders["mean"].dataset.x_names,
+         loaders["sdev"].dataset.x_names
+     )
+
+
+ ###############################################################################
+ ## Copula versions for gaussian regression
+ ###############################################################################
+
+ def gaussian_uniformizer(parameters, X_dict, y, epsilon=1e-3):
+     mu = X_dict["mean"] @ parameters["coef_mean"]
+     sigma = np.exp(X_dict["sdev"] @ parameters["coef_sdev"])
+     u = norm.cdf(y, loc=mu, scale=sigma)
+     u = np.clip(u, epsilon, 1 - epsilon)
+     return u
+
+ gaussian_copula_array = gcf.gaussian_copula_array_factory(
+     gaussian_regression_array, gaussian_uniformizer
+ )
+
+ gaussian_copula = gcf.gaussian_copula_factory(
+     gaussian_copula_array, format_gaussian_parameters_with_loaders,
+     param_name=['mean', 'sdev']
+ )
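The Gaussian marginal above regresses both the mean and the log standard deviation on covariates, so the objective is the heteroscedastic Gaussian negative log likelihood, 0.5 * (log(2 * pi * sigma^2) + (y - mu)^2 / sigma^2), summed over cells and genes. A standalone consistency check of that expression against scipy.stats.norm (toy shapes only, not the package's loaders):

import numpy as np
import torch
from scipy.stats import norm

rng = np.random.default_rng(1)
X = torch.tensor(rng.normal(size=(8, 3)), dtype=torch.float64)   # design matrix
y = torch.tensor(rng.normal(size=(8, 2)), dtype=torch.float64)   # two outcomes
coef_mean = torch.tensor(rng.normal(size=(3, 2)), dtype=torch.float64)
coef_sdev = torch.tensor(rng.normal(size=(3, 2)), dtype=torch.float64)

# mean and standard deviation are both log-linear / linear in the covariates
mu = X @ coef_mean
sigma = torch.exp(X @ coef_sdev)
nll = 0.5 * (torch.log(2 * torch.pi * sigma**2) + (y - mu) ** 2 / sigma**2)

# agrees elementwise with -log N(y; mu, sigma) from scipy
ref = -norm.logpdf(y.numpy(), loc=mu.numpy(), scale=sigma.numpy())
assert np.allclose(nll.numpy(), ref)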
scdesigner/estimators/gaussian_copula_factory.py
@@ -0,0 +1,152 @@
+ from ..data import stack_collate, multiple_formula_group_loader
+ from .. import data
+ from anndata import AnnData
+ from collections.abc import Callable
+ from typing import Union
+ from scipy.stats import norm
+ from torch.utils.data import DataLoader
+ import numpy as np
+ import pandas as pd
+
+ ###############################################################################
+ ## General copula factory functions
+ ###############################################################################
+
+
+ def gaussian_copula_array_factory(marginal_model: Callable, uniformizer: Callable):
+     def copula_fun(loaders: dict[str, DataLoader], lr: float = 0.1, epochs: int = 40, **kwargs):
+         # for the marginal model, ignore the groupings
+         # Strip all dataloaders and create a dictionary to pass to marginal_model
+         formula_loaders = {}
+         for key in loaders.keys():
+             formula_loaders[key] = strip_dataloader(loaders[key], pop="Stack" in type(loaders[key].dataset).__name__)
+
+         # Call marginal_model with the dictionary of stripped dataloaders
+         parameters = marginal_model(formula_loaders, lr=lr, epochs=epochs, **kwargs)
+
+         # estimate covariance, allowing for different groups
+         parameters["covariance"] = copula_covariance(parameters, loaders, uniformizer)
+         return parameters
+
+     return copula_fun
+
+
+ def gaussian_copula_factory(copula_array_fun: Callable,
+                             parameter_formatter: Callable,
+                             param_name: list = None):
+     def copula_fun(
+         adata: AnnData,
+         formula: Union[str, dict] = "~ 1",
+         grouping_var: str = None,
+         chunk_size: int = int(1e4),
+         batch_size: int = 512,
+         **kwargs
+     ) -> dict:
+
+         if param_name is not None:
+             formula = data.standardize_formula(formula, param_name)
+
+         dls = multiple_formula_group_loader(
+             adata,
+             formula,
+             grouping_var,
+             chunk_size=chunk_size,
+             batch_size=batch_size,
+         )  # returns a dictionary of dataloaders
+         parameters = copula_array_fun(dls, **kwargs)
+
+         # Pass the full dls to parameter_formatter so it can extract what it needs
+         parameters = parameter_formatter(
+             parameters, adata.var_names, dls
+         )
+         parameters["covariance"] = format_copula_parameters(parameters, adata.var_names)
+         return parameters
+
+     return copula_fun
+
+
+ def copula_covariance(parameters: dict, loaders: dict[str, DataLoader], uniformizer: Callable):
+     first_loader = next(iter(loaders.values()))
+     D = next(iter(first_loader))[1].shape[1]  # dimension of y
+     groups = first_loader.dataset.groups  # a list of strings of group names
+     sums = {g: np.zeros(D) for g in groups}
+     second_moments = {g: np.eye(D) for g in groups}
+     Ng = {g: 0 for g in groups}
+     keys = list(loaders.keys())
+     loaders = list(loaders.values())
+     num_keys = len(keys)
+
+     for batches in zip(*loaders):
+         x_batch_dict = {
+             keys[i]: batches[i][0].cpu().numpy() for i in range(num_keys)
+         }
+         y_batch = batches[0][1].cpu().numpy()
+         memberships = batches[0][2]  # should be identical for all keys
+
+         u = uniformizer(parameters, x_batch_dict, y_batch)
+         for g in groups:
+             ix = np.where(np.array(memberships) == g)
+             z = norm().ppf(u[ix])
+             second_moments[g] += z.T @ z
+             sums[g] += z.sum(axis=0)
+             Ng[g] += len(ix[0])
+
+     result = {}
+     for g in groups:
+         mean = sums[g] / Ng[g]
+         result[g] = second_moments[g] / Ng[g] - np.outer(mean, mean)
+
+     if len(groups) == 1:
+         return list(result.values())[0]
+     return result
+
+
+ ###############################################################################
+ ## Helpers to prepare and postprocess copula parameters
+ ###############################################################################
+
+
+ def group_indices(grouping_var: str, obs: pd.DataFrame) -> dict:
+     """
+     Returns a dictionary of group indices for each group in the grouping variable.
+     """
+     if grouping_var is None:
+         grouping_var = "_copula_group"
+         if "_copula_group" not in obs.columns:
+             obs["_copula_group"] = pd.Categorical(["shared_group"] * len(obs))
+     result = {}
+
+     for group in list(obs[grouping_var].dtype.categories):
+         result[group] = np.where(obs[grouping_var].values == group)[0]
+     return result
+
+
+ def clip(u: np.array, min: float = 1e-5, max: float = 1 - 1e-5) -> np.array:
+     u[u < min] = min
+     u[u > max] = max
+     return u
+
+
+ def format_copula_parameters(parameters: dict, var_names: list):
+     covariance = parameters["covariance"]
+     if type(covariance) is not dict:
+         covariance = pd.DataFrame(
+             parameters["covariance"], columns=list(var_names), index=list(var_names)
+         )
+     else:
+         for group in covariance.keys():
+             covariance[group] = pd.DataFrame(
+                 parameters["covariance"][group],
+                 columns=list(var_names),
+                 index=list(var_names),
+             )
+     return covariance
+
+
+ def strip_dataloader(dataloader, pop=False):
+     return DataLoader(
+         dataset=dataloader.dataset,
+         batch_sampler=dataloader.batch_sampler,
+         collate_fn=stack_collate(pop=pop, groups=False),
+     )
+
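copula_covariance above streams over batches, maps the uniformized values to normal scores z = norm.ppf(u), and accumulates per-group first and second moments so that the returned matrix is E[z z^T] - E[z] E[z]^T. An equivalent in-memory version for a single group, shown only to make the estimator explicit (the helper name and toy data are not part of the package):

import numpy as np
from scipy.stats import norm

def copula_covariance_dense(u, eps=1e-6):
    """Latent Gaussian covariance from uniformized observations u (n_cells x n_genes)."""
    z = norm.ppf(np.clip(u, eps, 1 - eps))  # normal scores
    mean = z.mean(axis=0)
    second_moment = z.T @ z / z.shape[0]
    return second_moment - np.outer(mean, mean)

# toy check: correlated uniforms recover the positive off-diagonal covariance
rng = np.random.default_rng(2)
latent = rng.multivariate_normal([0, 0], [[1.0, 0.6], [0.6, 1.0]], size=5_000)
u = norm.cdf(latent)
print(copula_covariance_dense(u).round(2))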
scdesigner/estimators/glm_factory.py
@@ -0,0 +1,75 @@
+ from tqdm import tqdm
+ from torch.utils.data import DataLoader
+ import torch
+
+
+ def glm_regression_factory(likelihood, initializer, postprocessor) -> dict:
+     def estimator(
+         dataloader: DataLoader,
+         lr: float = 0.1,
+         epochs: int = 40,
+     ):
+         device = check_device()
+         x, y = next(iter(dataloader))
+         params = initializer(x, y, device)
+         optimizer = torch.optim.Adam([params], lr=lr)
+
+         for epoch in range(epochs):
+             for x_batch, y_batch in (pbar := tqdm(dataloader, desc=f"Epoch {epoch + 1}/{epochs}", leave=False)):
+                 optimizer.zero_grad()
+                 loss = likelihood(params, x_batch, y_batch)
+                 loss.backward()
+                 optimizer.step()
+                 pbar.set_postfix_str(f"loss: {loss.item()}")
+
+
+         return postprocessor(params, x.shape[1], y.shape[1])
+
+     return estimator
+
+ def multiple_formula_regression_factory(likelihood, initializer, postprocessor) -> dict:
+     def estimator(
+         dataloaders: dict[str, DataLoader],
+         lr: float = 0.1,
+         epochs: int = 40,
+     ):
+         device = check_device()
+         x_dict = {}
+         y_dict = {}
+         for key in dataloaders.keys():
+             x_dict[key], y_dict[key] = next(iter(dataloaders[key]))
+         # check if all ys are the same
+         y_ref = y_dict[list(dataloaders.keys())[0]]
+         for key in dataloaders.keys():
+             if not torch.equal(y_dict[key], y_ref):
+                 raise ValueError(f"Ys are not the same for {key}")
+         params = initializer(x_dict, y_ref, device)  # x is a dictionary of tensors, y is a tensor
+         optimizer = torch.optim.Adam([params], lr=lr)
+
+         keys = list(dataloaders.keys())
+         loaders = list(dataloaders.values())
+
+         for epoch in range(epochs):
+             num_keys = len(keys)
+             for batches in (pbar := tqdm(zip(*loaders), desc=f"Epoch {epoch + 1}/{epochs}", leave=False)):
+                 x_batch_dict = {
+                     keys[i]: batches[i][0].to(device) for i in range(num_keys)
+                 }
+                 y_batch = batches[0][1].to(device)
+                 optimizer.zero_grad()
+                 loss = likelihood(params, x_batch_dict, y_batch)
+                 loss.backward()
+                 optimizer.step()
+                 pbar.set_postfix_str(f"loss: {loss.item()}")
+
+         return postprocessor(params, x_dict, y_ref)
+
+     return estimator
+
+
+ def check_device():
+     return torch.device(
+         "cuda"
+         if torch.cuda.is_available()
+         else "mps" if torch.backends.mps.is_available() else "cpu"
+     )
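multiple_formula_regression_factory above turns a likelihood, an initializer, and a postprocessor into a shared Adam training loop over a dictionary of DataLoaders keyed by formula name ("mean", "sdev", ...). A hedged sketch of plugging in a custom Poisson likelihood using plain TensorDataset loaders; in the package the loaders come from scdesigner.data instead, and the helper names here (poisson_nll, init, post) are illustrative, not part of the release:

import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset
from scdesigner.estimators.glm_factory import multiple_formula_regression_factory

def poisson_nll(params, X_dict, y):
    coef = params.reshape(X_dict["mean"].shape[1], y.shape[1])
    log_mu = X_dict["mean"] @ coef
    # negative Poisson log likelihood, up to the additive log(y!) term
    return -(y * log_mu - torch.exp(log_mu)).sum()

def init(X_dict, y, device):
    return torch.zeros(X_dict["mean"].shape[1] * y.shape[1], requires_grad=True, device=device)

def post(params, X_dict, y):
    coef = params.detach().cpu().numpy().reshape(X_dict["mean"].shape[1], y.shape[1])
    return {"coef_mean": coef}

# synthetic Poisson counts with an intercept and one covariate
rng = np.random.default_rng(3)
X = torch.tensor(np.hstack([np.ones((256, 1)), rng.normal(size=(256, 1))]), dtype=torch.float32)
true_coef = torch.tensor([[0.5, 1.0], [0.3, -0.2]], dtype=torch.float32)
y = torch.poisson(torch.exp(X @ true_coef))

fit = multiple_formula_regression_factory(poisson_nll, init, post)
params = fit({"mean": DataLoader(TensorDataset(X, y), batch_size=64)}, lr=0.1, epochs=50)
print(params["coef_mean"])  # should roughly recover true_coef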