scdesigner-0.0.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This version of scdesigner has been flagged as potentially problematic.

Files changed (66)
  1. scdesigner/__init__.py +0 -0
  2. scdesigner/data/__init__.py +16 -0
  3. scdesigner/data/formula.py +137 -0
  4. scdesigner/data/group.py +123 -0
  5. scdesigner/data/sparse.py +39 -0
  6. scdesigner/diagnose/__init__.py +65 -0
  7. scdesigner/diagnose/aic_bic.py +119 -0
  8. scdesigner/diagnose/plot.py +242 -0
  9. scdesigner/estimators/__init__.py +27 -0
  10. scdesigner/estimators/bernoulli.py +85 -0
  11. scdesigner/estimators/gaussian.py +121 -0
  12. scdesigner/estimators/gaussian_copula_factory.py +152 -0
  13. scdesigner/estimators/glm_factory.py +75 -0
  14. scdesigner/estimators/negbin.py +129 -0
  15. scdesigner/estimators/pnmf.py +160 -0
  16. scdesigner/estimators/poisson.py +100 -0
  17. scdesigner/estimators/zero_inflated_negbin.py +195 -0
  18. scdesigner/estimators/zero_inflated_poisson.py +85 -0
  19. scdesigner/format/__init__.py +4 -0
  20. scdesigner/format/format.py +20 -0
  21. scdesigner/format/print.py +30 -0
  22. scdesigner/minimal/__init__.py +17 -0
  23. scdesigner/minimal/bernoulli.py +61 -0
  24. scdesigner/minimal/composite.py +119 -0
  25. scdesigner/minimal/copula.py +33 -0
  26. scdesigner/minimal/formula.py +23 -0
  27. scdesigner/minimal/gaussian.py +65 -0
  28. scdesigner/minimal/kwargs.py +24 -0
  29. scdesigner/minimal/loader.py +166 -0
  30. scdesigner/minimal/marginal.py +140 -0
  31. scdesigner/minimal/negbin.py +73 -0
  32. scdesigner/minimal/positive_nonnegative_matrix_factorization.py +231 -0
  33. scdesigner/minimal/scd3.py +95 -0
  34. scdesigner/minimal/scd3_instances.py +50 -0
  35. scdesigner/minimal/simulator.py +25 -0
  36. scdesigner/minimal/standard_covariance.py +124 -0
  37. scdesigner/minimal/transform.py +145 -0
  38. scdesigner/minimal/zero_inflated_negbin.py +86 -0
  39. scdesigner/predictors/__init__.py +15 -0
  40. scdesigner/predictors/bernoulli.py +9 -0
  41. scdesigner/predictors/gaussian.py +16 -0
  42. scdesigner/predictors/negbin.py +17 -0
  43. scdesigner/predictors/poisson.py +12 -0
  44. scdesigner/predictors/zero_inflated_negbin.py +18 -0
  45. scdesigner/predictors/zero_inflated_poisson.py +18 -0
  46. scdesigner/samplers/__init__.py +23 -0
  47. scdesigner/samplers/bernoulli.py +27 -0
  48. scdesigner/samplers/gaussian.py +25 -0
  49. scdesigner/samplers/glm_factory.py +41 -0
  50. scdesigner/samplers/negbin.py +25 -0
  51. scdesigner/samplers/poisson.py +25 -0
  52. scdesigner/samplers/zero_inflated_negbin.py +40 -0
  53. scdesigner/samplers/zero_inflated_poisson.py +16 -0
  54. scdesigner/simulators/__init__.py +31 -0
  55. scdesigner/simulators/composite_regressor.py +72 -0
  56. scdesigner/simulators/glm_simulator.py +167 -0
  57. scdesigner/simulators/pnmf_regression.py +61 -0
  58. scdesigner/transform/__init__.py +7 -0
  59. scdesigner/transform/amplify.py +14 -0
  60. scdesigner/transform/mask.py +33 -0
  61. scdesigner/transform/nullify.py +25 -0
  62. scdesigner/transform/split.py +23 -0
  63. scdesigner/transform/substitute.py +14 -0
  64. scdesigner-0.0.1.dist-info/METADATA +23 -0
  65. scdesigner-0.0.1.dist-info/RECORD +66 -0
  66. scdesigner-0.0.1.dist-info/WHEEL +4 -0
scdesigner/estimators/negbin.py
@@ -0,0 +1,129 @@
+ from . import gaussian_copula_factory as gcf
+ from . import glm_factory as factory
+ from .. import format
+ from .. import data
+ from anndata import AnnData
+ from scipy.stats import nbinom
+ import numpy as np
+ import pandas as pd
+ import torch
+ from typing import Union
+
+ ###############################################################################
+ ## Regression functions that operate on numpy arrays
+ ###############################################################################
+
+
+ def negbin_regression_likelihood(params, X_dict, y):
+     n_mean_features = X_dict["mean"].shape[1]
+     n_dispersion_features = X_dict["dispersion"].shape[1]
+     n_outcomes = y.shape[1]
+
+     # form the mean and dispersion parameters
+     coef_mean = params[: n_mean_features * n_outcomes].reshape(
+         n_mean_features, n_outcomes
+     )
+     coef_dispersion = params[n_mean_features * n_outcomes :].reshape(
+         n_dispersion_features, n_outcomes
+     )
+     r = torch.exp(X_dict["dispersion"] @ coef_dispersion)
+     mu = torch.exp(X_dict["mean"] @ coef_mean)
+
+     # per-observation log likelihood of a negative binomial with mean mu and
+     # size r (negated on return)
+     log_likelihood = (
+         torch.lgamma(y + r)
+         - torch.lgamma(r)
+         - torch.lgamma(y + 1)
+         + r * torch.log(r)
+         + y * torch.log(mu)
+         - (r + y) * torch.log(r + mu)
+     )
+
+     return -torch.sum(log_likelihood)
+
+
+ def negbin_initializer(x_dict, y, device):
+     n_mean_features = x_dict["mean"].shape[1]
+     n_outcomes = y.shape[1]
+     n_dispersion_features = x_dict["dispersion"].shape[1]
+     return torch.zeros(
+         n_mean_features * n_outcomes + n_dispersion_features * n_outcomes,
+         requires_grad=True, device=device
+     )
+
+
+ def negbin_postprocessor(params, x_dict, y):
+     n_mean_features = x_dict["mean"].shape[1]
+     n_outcomes = y.shape[1]
+     n_dispersion_features = x_dict["dispersion"].shape[1]
+     coef_mean = format.to_np(params[: n_mean_features * n_outcomes]).reshape(
+         n_mean_features, n_outcomes
+     )
+     coef_dispersion = format.to_np(params[n_mean_features * n_outcomes :]).reshape(
+         n_dispersion_features, n_outcomes
+     )
+     return {"coef_mean": coef_mean, "coef_dispersion": coef_dispersion}
+
+
+ negbin_regression_array = factory.multiple_formula_regression_factory(
+     negbin_regression_likelihood, negbin_initializer, negbin_postprocessor
+ )
+
+
+ ###############################################################################
+ ## Regression functions that operate on AnnData objects
+ ###############################################################################
+
+ def format_negbin_parameters(
+     parameters: dict, var_names: list, mean_coef_index: list,
+     dispersion_coef_index: list
+ ) -> dict:
+     parameters["coef_mean"] = pd.DataFrame(
+         parameters["coef_mean"], columns=var_names, index=mean_coef_index
+     )
+     parameters["coef_dispersion"] = pd.DataFrame(
+         parameters["coef_dispersion"], columns=var_names, index=dispersion_coef_index
+     )
+     return parameters
+
+
+ def format_negbin_parameters_with_loaders(
+     parameters: dict, var_names: list, dls: dict
+ ) -> dict:
+     # extract the coefficient indices from the dataloaders
+     mean_coef_index = dls["mean"].dataset.x_names
+     dispersion_coef_index = dls["dispersion"].dataset.x_names
+     return format_negbin_parameters(
+         parameters, var_names, mean_coef_index, dispersion_coef_index
+     )
+
+
+ def negbin_regression(
+     adata: AnnData, formula: Union[str, dict], chunk_size: int = int(1e4),
+     batch_size: int = 512, **kwargs
+ ) -> dict:
+     formula = data.standardize_formula(formula, allowed_keys=['mean', 'dispersion'])
+     loaders = data.multiple_formula_loader(
+         adata, formula, chunk_size=chunk_size, batch_size=batch_size
+     )
+     parameters = negbin_regression_array(loaders, **kwargs)
+     return format_negbin_parameters(
+         parameters, list(adata.var_names),
+         loaders["mean"].dataset.x_names, loaders["dispersion"].dataset.x_names
+     )
+
+
+ ###############################################################################
+ ## Copula versions for negative binomial regression
+ ###############################################################################
+
+
+ def negbin_uniformizer(parameters, X_dict, y, epsilon=1e-3):
+     r = np.exp(X_dict["dispersion"] @ parameters["coef_dispersion"])
+     mu = np.exp(X_dict["mean"] @ parameters["coef_mean"])
+     u1 = nbinom(n=r, p=r / (r + mu)).cdf(y)
+     u2 = np.where(y > 0, nbinom(n=r, p=r / (r + mu)).cdf(y - 1), 0)
+     v = np.random.uniform(size=y.shape)
+     return np.clip(v * u1 + (1 - v) * u2, epsilon, 1 - epsilon)
+
+
+ negbin_copula_array = gcf.gaussian_copula_array_factory(
+     negbin_regression_array, negbin_uniformizer
+ )  # should accept a dictionary of dataloaders
+
+ negbin_copula = gcf.gaussian_copula_factory(
+     negbin_copula_array, format_negbin_parameters_with_loaders,
+     param_name=['mean', 'dispersion']
+ )
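
The file fits per-gene negative binomial GLMs, with separate formulas for the mean and dispersion, and wraps them in a Gaussian copula. A minimal usage sketch on synthetic data (hypothetical: it assumes negbin_regression is importable from scdesigner.estimators.negbin and that adata.obs carries a "condition" column, neither of which is shown in this diff):

import anndata
import numpy as np
import pandas as pd
from scdesigner.estimators.negbin import negbin_regression

# toy AnnData: 200 cells x 10 genes, one binary covariate
counts = np.random.poisson(5.0, size=(200, 10)).astype(float)
obs = pd.DataFrame(
    {"condition": np.random.choice(["a", "b"], size=200)},
    index=[f"cell{i}" for i in range(200)],
)
adata = anndata.AnnData(X=counts, obs=obs)

# the mean varies with condition; the dispersion is intercept-only
params = negbin_regression(adata, {"mean": "~ condition", "dispersion": "~ 1"})
print(params["coef_mean"])        # rows: formula terms, columns: genes
print(params["coef_dispersion"])  # a single intercept row per gene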
scdesigner/estimators/pnmf.py
@@ -0,0 +1,160 @@
+ from scipy.sparse.linalg import svds
+ from sklearn.cluster import KMeans
+ import numpy as np
+ import pandas as pd
+ import torch
+
+
+ def pnmf(log_data, nbase=3, **kwargs):
+     """
+     Computes the PNMF weight and score matrices.
+
+     :param log_data: log-transformed read data as a numpy array (gene x cell)
+     :param nbase: number of basis vectors (factors) to extract
+     :return: W (weights, gene x base) and S (scores, base x cell) as numpy arrays
+     """
+     U = left_singular(log_data, nbase)
+     W = pnmf_eucdist(log_data, U, **kwargs)
+     W = W / np.linalg.norm(W, ord=2)
+     S = W.T @ log_data
+     return W, S
+
+
+ def gamma_regression_array(
+     x: np.array, y: np.array, batch_size: int = 512, lr: float = 0.1, epochs: int = 40
+ ) -> dict:
+     x = torch.tensor(x, dtype=torch.float32)
+     y = torch.tensor(y, dtype=torch.float32)
+
+     # a and beta are optimized on the log scale; loc is unconstrained
+     n_features, n_outcomes = x.shape[1], y.shape[1]
+     a = torch.zeros(n_features * n_outcomes, requires_grad=True)
+     loc = torch.zeros(n_features * n_outcomes, requires_grad=True)
+     beta = torch.zeros(n_features * n_outcomes, requires_grad=True)
+     optimizer = torch.optim.Adam([a, loc, beta], lr=lr)
+
+     for i in range(epochs):
+         optimizer.zero_grad()
+         loss = negative_gamma_log_likelihood(a, beta, loc, x, y)
+         loss.backward()
+         optimizer.step()
+
+     a = to_np(a).reshape(n_features, n_outcomes)
+     loc = to_np(loc).reshape(n_features, n_outcomes)
+     beta = to_np(beta).reshape(n_features, n_outcomes)
+     return {"a": a, "loc": loc, "beta": beta}
+
47
+
48
+ def class_generator(score, n_clusters=3):
49
+ """
50
+ Generates one-hot encoding for score classes
51
+ """
52
+ kmeans = KMeans(n_clusters, random_state=0) # Specify the number of clusters
53
+ kmeans.fit(score.T)
54
+ labels = kmeans.labels_
55
+ num_classes = len(np.unique(labels))
56
+ one_hot = np.eye(num_classes)[labels].astype(int)
57
+ return labels
+
+
+ ###############################################################################
+ ## Helpers for deriving PNMF
+ ###############################################################################
+
+
+ def pnmf_eucdist(X, W_init, maxIter=500, threshold=1e-4, tol=1e-10, verbose=False):
+     # initialization: W_init is typically the (absolute) PCA basis of X
+     W = W_init
+     XX = X @ X.T
+
+     # multiplicative updates
+     for it in range(maxIter):
+         if verbose and (it + 1) % 10 == 0:
+             print("%d iterations used." % (it + 1))
+         W_old = W
+
+         XXW = XX @ W
+         SclFactor = np.dot(W, W.T @ XXW) + np.dot(XXW, W.T @ W)
+
+         # QuotientLB
+         SclFactor = MatFindlb(SclFactor, tol)
+         SclFactor = XXW / SclFactor
+         # note: an in-place `W *= SclFactor` would also mutate W_old (the two
+         # names alias the same array), breaking the convergence check below
+         W = W * SclFactor
+
+         norm_W = np.linalg.norm(W)
+         W /= norm_W
+         W = MatFind(W, tol)
+
+         diffW = np.linalg.norm(W_old - W) / np.linalg.norm(W_old)
+         if diffW < threshold:
+             break
+
+     return W
+
+
+ # left singular vectors of X (absolute values, for a nonnegative start)
+ def left_singular(X, k):
+     U, _, _ = svds(X, k=k)
+     return np.abs(U)
+
+
+ def MatFindlb(A, lb):
+     # elementwise lower bound: replace entries below lb with lb
+     B = np.ones(A.shape) * lb
+     Alb = np.where(A < lb, B, A)
+     return Alb
+
+
+ def MatFind(A, ZeroThres):
+     # truncate entries below ZeroThres to zero
+     B = np.zeros(A.shape)
+     Atrunc = np.where(A < ZeroThres, B, A)
+     return Atrunc
+
+
+ ###############################################################################
+ ## Helpers for training PNMF regression
+ ###############################################################################
+
+
+ def shifted_gamma_pdf(x, alpha, beta, loc):
+     # despite the name, this returns the negative mean log-density (a loss),
+     # with a large penalty for observations below the location shift
+     if not torch.is_tensor(x):
+         x = torch.tensor(x)
+     mask = x < loc
+     y_clamped = torch.clamp(x - loc, min=1e-12)
+
+     log_pdf = (
+         alpha * torch.log(beta)
+         - torch.lgamma(alpha)
+         + (alpha - 1) * torch.log(y_clamped)
+         - beta * y_clamped
+     )
+     loss = -torch.mean(log_pdf[~mask])
+     n_invalid = mask.sum()
+     if n_invalid > 0:  # force samples to be greater than loc
+         loss = loss + 1e10 * n_invalid.float()
+     return loss
+
+
+ def negative_gamma_log_likelihood(log_a, log_beta, loc, X, y):
+     n_features = X.shape[1]
+     n_outcomes = y.shape[1]
+
+     a = torch.exp(log_a.reshape(n_features, n_outcomes))
+     beta = torch.exp(log_beta.reshape(n_features, n_outcomes))
+     loc = loc.reshape(n_features, n_outcomes)
+     return shifted_gamma_pdf(y, X @ a, X @ beta, X @ loc)
+
+
+ def to_np(x):
+     return x.detach().cpu().numpy()
+
+
+ def format_gamma_parameters(
+     parameters: dict,
+     W_index: list,
+     coef_index: list,
+ ) -> dict:
+     parameters["a"] = pd.DataFrame(parameters["a"], index=coef_index)
+     parameters["loc"] = pd.DataFrame(parameters["loc"], index=coef_index)
+     parameters["beta"] = pd.DataFrame(parameters["beta"], index=coef_index)
+     parameters["W"] = pd.DataFrame(parameters["W"], index=W_index)
+     return parameters
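
pnmf seeds the multiplicative updates with the absolute left singular vectors and returns a nonnegative basis W together with scores S; class_generator then turns k-means clusters of the scores into a one-hot design matrix. A small sketch on synthetic data (assuming both functions are importable from scdesigner.estimators.pnmf):

import numpy as np
from scdesigner.estimators.pnmf import class_generator, pnmf

# toy log-transformed expression: 50 genes x 300 cells
rng = np.random.default_rng(0)
log_data = np.log1p(rng.poisson(3.0, size=(50, 300)).astype(float))

W, S = pnmf(log_data, nbase=3)              # W: 50 x 3 weights, S: 3 x 300 scores
classes = class_generator(S, n_clusters=3)  # 300 x 3 one-hot cluster matrix
print(W.shape, S.shape, classes.shape)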
scdesigner/estimators/poisson.py
@@ -0,0 +1,100 @@
+ from . import gaussian_copula_factory as gcf
+ from . import glm_factory as factory
+ from .. import data
+ from .. import format
+ from anndata import AnnData
+ from scipy.stats import poisson
+ import numpy as np
+ import pandas as pd
+ import torch
+
+ ###############################################################################
+ ## Regression functions that operate on numpy arrays
+ ###############################################################################
+
+
+ def poisson_regression_likelihood(params, X, y, epsilon=1e-6):
+     # get appropriate parameter shape
+     n_features = X['mean'].shape[1]
+     n_outcomes = y.shape[1]
+
+     # compute the negative log likelihood
+     beta = params.reshape(n_features, n_outcomes)
+     mu = torch.exp(X['mean'] @ beta)
+     log_likelihood = y * torch.log(mu + epsilon) - mu - torch.lgamma(y + 1)
+     return -torch.sum(log_likelihood)
+
+
+ def poisson_initializer(x, y, device):
+     n_features, n_outcomes = x['mean'].shape[1], y.shape[1]
+     return torch.zeros(n_features * n_outcomes, requires_grad=True, device=device)
+
+
+ def poisson_postprocessor(params, x, y):
+     coef_mean = format.to_np(params).reshape(x['mean'].shape[1], y.shape[1])
+     return {"coef_mean": coef_mean}
+
+
+ poisson_regression_array = factory.multiple_formula_regression_factory(
+     poisson_regression_likelihood, poisson_initializer, poisson_postprocessor
+ )
+
+ ###############################################################################
+ ## Regression functions that operate on AnnData objects
+ ###############################################################################
+
+
+ def format_poisson_parameters(
+     parameters: dict, var_names: list, coef_index: list
+ ) -> dict:
+     parameters["coef_mean"] = pd.DataFrame(
+         parameters["coef_mean"], columns=var_names, index=coef_index
+     )
+     return parameters
+
+
+ def poisson_regression(
+     adata: AnnData,
+     formula: str,
+     chunk_size: int = int(1e4),
+     batch_size: int = 512,
+     **kwargs
+ ) -> dict:
+     formula = data.standardize_formula(formula, allowed_keys=['mean'])
+     loaders = data.multiple_formula_loader(
+         adata, formula, chunk_size=chunk_size, batch_size=batch_size
+     )
+
+     parameters = poisson_regression_array(loaders, **kwargs)
+     return format_poisson_parameters(
+         parameters, list(adata.var_names), loaders["mean"].dataset.x_names
+     )
+
+
+ ###############################################################################
+ ## Copula versions for poisson regression
+ ###############################################################################
+
+
+ def poisson_uniformizer(parameters, x, y, epsilon=1e-3):
+     mu = np.exp(x['mean'] @ parameters["coef_mean"])
+     u1 = poisson(mu).cdf(y)
+     u2 = np.where(y > 0, poisson(mu).cdf(y - 1), 0)
+     v = np.random.uniform(size=y.shape)
+     return np.clip(v * u1 + (1 - v) * u2, epsilon, 1 - epsilon)
+
+
+ def format_poisson_parameters_with_loaders(
+     parameters: dict, var_names: list, dls: dict
+ ) -> dict:
+     return format_poisson_parameters(
+         parameters, var_names, dls["mean"].dataset.x_names
+     )
+
+
+ poisson_copula_array = gcf.gaussian_copula_array_factory(
+     poisson_regression_array, poisson_uniformizer
+ )
+
+ poisson_copula = gcf.gaussian_copula_factory(
+     poisson_copula_array, format_poisson_parameters_with_loaders, ['mean']
+ )
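
poisson_uniformizer (like its negative binomial counterparts) applies the randomized distributional transform for discrete data: with F the model CDF and V ~ Uniform(0, 1), the value V * F(y) + (1 - V) * F(y - 1) is exactly Uniform(0, 1) when the model is correct, which is what lets the copula factory map counts into Gaussian space. A standalone numpy/scipy check of this identity (independent of the package):

import numpy as np
from scipy.stats import kstest, poisson

rng = np.random.default_rng(0)
mu = 4.0
y = rng.poisson(mu, size=100_000)

# randomized distributional transform: u = v * F(y) + (1 - v) * F(y - 1)
v = rng.uniform(size=y.shape)
u = v * poisson(mu).cdf(y) + (1 - v) * np.where(y > 0, poisson(mu).cdf(y - 1), 0.0)

# u should be indistinguishable from Uniform(0, 1)
print(kstest(u, "uniform"))  # expect a large p-value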
scdesigner/estimators/zero_inflated_negbin.py
@@ -0,0 +1,195 @@
+ import warnings
+ from . import gaussian_copula_factory as gcf
+ from .. import format
+ from .. import data
+ from . import glm_factory as factory
+ from anndata import AnnData
+ from scipy.stats import nbinom
+ import numpy as np
+ import pandas as pd
+ import torch
+ from typing import Union
+ from scipy.special import expit
+
+ ###############################################################################
+ ## Regression functions that operate on numpy arrays
+ ###############################################################################
+
+
+ def zero_inflated_negbin_regression_likelihood(params, X_dict, y):
+     # get appropriate parameter shape
+     n_mean_features = X_dict["mean"].shape[1]
+     n_dispersion_features = X_dict["dispersion"].shape[1]
+     n_zero_inflation_features = X_dict["zero_inflation"].shape[1]
+     n_outcomes = y.shape[1]
+
+     # define the likelihood parameters
+     n_mean = n_mean_features * n_outcomes
+     n_dispersion = n_dispersion_features * n_outcomes
+     coef_mean = params[:n_mean].reshape(n_mean_features, n_outcomes)
+     coef_dispersion = params[n_mean : n_mean + n_dispersion].reshape(
+         n_dispersion_features, n_outcomes
+     )
+     coef_zero_inflation = params[n_mean + n_dispersion :].reshape(
+         n_zero_inflation_features, n_outcomes
+     )
+
+     mu = torch.exp(X_dict["mean"] @ coef_mean)
+     r = torch.exp(X_dict["dispersion"] @ coef_dispersion)
+     pi = torch.sigmoid(X_dict["zero_inflation"] @ coef_zero_inflation)
+
+     # negative binomial component
+     negbin_loglikelihood = (
+         torch.lgamma(y + r)
+         - torch.lgamma(r)
+         - torch.lgamma(y + 1)
+         + r * torch.log(r)
+         + y * torch.log(mu)
+         - (r + y) * torch.log(r + mu)
+     )
+
+     # return the mixture, with an offset to prevent log(0)
+     log_likelihood = torch.log(
+         pi * (y == 0) + (1 - pi) * torch.exp(negbin_loglikelihood) + 1e-10
+     )
+     return -torch.sum(log_likelihood)
+
+
+ def zero_inflated_negbin_initializer(X_dict, y, device):
+     n_mean_features = X_dict["mean"].shape[1]
+     n_dispersion_features = X_dict["dispersion"].shape[1]
+     n_zero_inflation_features = X_dict["zero_inflation"].shape[1]
+     n_outcomes = y.shape[1]
+     return torch.zeros(
+         (n_mean_features + n_dispersion_features + n_zero_inflation_features)
+         * n_outcomes,
+         requires_grad=True, device=device
+     )
+
+
+ def zero_inflated_negbin_postprocessor(params, X_dict, y):
+     n_mean_features = X_dict["mean"].shape[1]
+     n_dispersion_features = X_dict["dispersion"].shape[1]
+     n_zero_inflation_features = X_dict["zero_inflation"].shape[1]
+     n_outcomes = y.shape[1]
+     n_mean = n_mean_features * n_outcomes
+     n_dispersion = n_dispersion_features * n_outcomes
+     coef_mean = format.to_np(params[:n_mean]).reshape(n_mean_features, n_outcomes)
+     coef_dispersion = format.to_np(params[n_mean : n_mean + n_dispersion]).reshape(
+         n_dispersion_features, n_outcomes
+     )
+     coef_zero_inflation = format.to_np(params[n_mean + n_dispersion :]).reshape(
+         n_zero_inflation_features, n_outcomes
+     )
+     return {
+         "coef_mean": coef_mean,
+         "coef_dispersion": coef_dispersion,
+         "coef_zero_inflation": coef_zero_inflation,
+     }
+
+
+ zero_inflated_negbin_regression_array = factory.multiple_formula_regression_factory(
+     zero_inflated_negbin_regression_likelihood,
+     zero_inflated_negbin_initializer,
+     zero_inflated_negbin_postprocessor,
+ )
+
+ ###############################################################################
+ ## Regression functions that operate on AnnData objects
+ ###############################################################################
+
+
+ def format_zero_inflated_negbin_parameters(
+     parameters: dict, var_names: list, mean_coef_index: list,
+     dispersion_coef_index: list, zero_inflation_coef_index: list
+ ) -> dict:
+     parameters["coef_mean"] = pd.DataFrame(
+         parameters["coef_mean"], columns=var_names, index=mean_coef_index
+     )
+     parameters["coef_dispersion"] = pd.DataFrame(
+         parameters["coef_dispersion"], columns=var_names, index=dispersion_coef_index
+     )
+     parameters["coef_zero_inflation"] = pd.DataFrame(
+         parameters["coef_zero_inflation"], columns=var_names,
+         index=zero_inflation_coef_index
+     )
+     return parameters
+
+
+ def format_zero_inflated_negbin_parameters_with_loaders(
+     parameters: dict, var_names: list, dls: dict
+ ) -> dict:
+     mean_coef_index = dls["mean"].dataset.x_names
+     dispersion_coef_index = dls["dispersion"].dataset.x_names
+     zero_inflation_coef_index = dls["zero_inflation"].dataset.x_names
+     return format_zero_inflated_negbin_parameters(
+         parameters, var_names, mean_coef_index, dispersion_coef_index,
+         zero_inflation_coef_index
+     )
+
+
+ def standardize_zero_inflated_negbin_formula(formula: Union[str, dict]) -> dict:
+     '''
+     Convert a string formula to a dict and validate its type.
+     If formula is a string, it is taken as the formula for the mean parameter.
+     If formula is a dictionary, it maps the mean, dispersion, and
+     zero_inflation parameters to their formulas.
+     '''
+     # Convert string formula to dict and validate type
+     formula = {'mean': formula, 'dispersion': '~ 1', 'zero_inflation': '~ 1'} \
+         if isinstance(formula, str) else formula
+     if not isinstance(formula, dict):
+         raise ValueError("formula must be a string or a dictionary")
+
+     # Define allowed keys and set defaults
+     allowed_keys = {'mean', 'dispersion', 'zero_inflation'}
+     formula_keys = set(formula.keys())
+
+     # check for required keys
+     if not formula_keys & allowed_keys:
+         raise ValueError(
+             "formula must have at least one of the following keys: "
+             "mean, dispersion, zero_inflation"
+         )
+
+     # warn about unused keys
+     if extra_keys := formula_keys - allowed_keys:
+         warnings.warn(
+             "Invalid formulas in dictionary for zero-inflated "
+             f"negative binomial regression: {extra_keys}",
+             UserWarning,
+         )
+
+     # set default values for missing keys
+     formula.update({k: '~ 1' for k in allowed_keys - formula_keys})
+     return formula
+
+
+ def zero_inflated_negbin_regression(
+     adata: AnnData, formula: Union[str, dict], chunk_size: int = int(1e4),
+     batch_size: int = 512, **kwargs
+ ) -> dict:
+     formula = data.standardize_formula(
+         formula, allowed_keys=['mean', 'dispersion', 'zero_inflation']
+     )
+
+     loaders = data.multiple_formula_loader(
+         adata, formula, chunk_size=chunk_size, batch_size=batch_size
+     )
+
+     parameters = zero_inflated_negbin_regression_array(loaders, **kwargs)
+     return format_zero_inflated_negbin_parameters(
+         parameters, list(adata.var_names), loaders["mean"].dataset.x_names,
+         loaders["dispersion"].dataset.x_names,
+         loaders["zero_inflation"].dataset.x_names
+     )
+
+
+ ###############################################################################
+ ## Copula versions for ZINB regression
+ ###############################################################################
+
+
+ def zero_inflated_negbin_uniformizer(parameters, X_dict, y, epsilon=1e-3):
+     r, mu, pi = (
+         np.exp(X_dict["dispersion"] @ parameters["coef_dispersion"]),
+         np.exp(X_dict["mean"] @ parameters["coef_mean"]),
+         expit(X_dict["zero_inflation"] @ parameters["coef_zero_inflation"]),
+     )
+     nb_distn = nbinom(n=r, p=r / (r + mu))
+     u1 = pi + (1 - pi) * nb_distn.cdf(y)
+     u2 = np.where(y > 0, pi + (1 - pi) * nb_distn.cdf(y - 1), 0)
+     v = np.random.uniform(size=y.shape)
+     return np.clip(v * u1 + (1 - v) * u2, epsilon, 1 - epsilon)
+
+
+ zero_inflated_negbin_copula = gcf.gaussian_copula_factory(
+     gcf.gaussian_copula_array_factory(
+         zero_inflated_negbin_regression_array, zero_inflated_negbin_uniformizer
+     ),
+     format_zero_inflated_negbin_parameters_with_loaders,
+     param_name=['mean', 'dispersion', 'zero_inflation'],
+ )
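
standardize_zero_inflated_negbin_formula accepts either a single string (taken as the mean formula) or a partial dictionary, and fills any missing parameters with intercept-only models. A short sketch of the expected behavior (assuming the function is importable from scdesigner.estimators.zero_inflated_negbin):

from scdesigner.estimators.zero_inflated_negbin import (
    standardize_zero_inflated_negbin_formula,
)

# a bare string becomes the mean formula; the rest default to '~ 1'
print(standardize_zero_inflated_negbin_formula("~ condition"))
# {'mean': '~ condition', 'dispersion': '~ 1', 'zero_inflation': '~ 1'}

# missing keys in a dictionary are filled the same way (the insertion
# order of the filled-in keys is not guaranteed)
print(standardize_zero_inflated_negbin_formula({"zero_inflation": "~ batch"}))
# e.g. {'zero_inflation': '~ batch', 'mean': '~ 1', 'dispersion': '~ 1'}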
scdesigner/estimators/zero_inflated_poisson.py
@@ -0,0 +1,85 @@
+ from anndata import AnnData
+ from .. import format
+ from .. import data
+ from . import glm_factory as factory
+ import pandas as pd
+ import torch
+ from typing import Union
+
+
+ def zero_inflated_poisson_regression_likelihood(params, X, y):
+     # get appropriate parameter shape
+     mean_n_features = X['mean'].shape[1]
+     zero_inflation_n_features = X['zero_inflation'].shape[1]
+     n_outcomes = y.shape[1]
+
+     # define the likelihood parameters
+     b_elem = mean_n_features * n_outcomes
+     coef_mean = params[:b_elem].reshape(mean_n_features, n_outcomes)
+     coef_zero_inflation = params[b_elem:].reshape(
+         zero_inflation_n_features, n_outcomes
+     )
+
+     zero_inflation = torch.sigmoid(X['zero_inflation'] @ coef_zero_inflation)
+     mu = torch.exp(X['mean'] @ coef_mean)
+     poisson_loglikelihood = y * torch.log(mu + 1e-10) - mu - torch.lgamma(y + 1)
+
+     # return the mixture, with an offset to prevent log(0)
+     log_likelihood = torch.log(
+         zero_inflation * (y == 0)
+         + (1 - zero_inflation) * torch.exp(poisson_loglikelihood)
+         + 1e-10
+     )
+     return -torch.sum(log_likelihood)
+
+
+ def zero_inflated_poisson_initializer(x, y, device):
+     mean_n_features = x['mean'].shape[1]
+     zero_inflation_n_features = x['zero_inflation'].shape[1]
+     n_outcomes = y.shape[1]
+     return torch.zeros(
+         (mean_n_features + zero_inflation_n_features) * n_outcomes,
+         requires_grad=True, device=device
+     )
+
+
+ def zero_inflated_poisson_postprocessor(params, x, y):
+     mean_n_features = x['mean'].shape[1]
+     zero_inflation_n_features = x['zero_inflation'].shape[1]
+     n_outcomes = y.shape[1]
+     b_elem = mean_n_features * n_outcomes
+     coef_mean = format.to_np(params[:b_elem]).reshape(mean_n_features, n_outcomes)
+     coef_zero_inflation = format.to_np(params[b_elem:]).reshape(
+         zero_inflation_n_features, n_outcomes
+     )
+     return {"coef_mean": coef_mean, "coef_zero_inflation": coef_zero_inflation}
+
+
+ zero_inflated_poisson_regression_array = factory.multiple_formula_regression_factory(
+     zero_inflated_poisson_regression_likelihood,
+     zero_inflated_poisson_initializer,
+     zero_inflated_poisson_postprocessor,
+ )
+
+
+ ###############################################################################
+ ## Regression functions that operate on AnnData objects
+ ###############################################################################
+
+
+ def format_zero_inflated_poisson_parameters(
+     parameters: dict, var_names: list, mean_coef_index: list,
+     zero_inflation_coef_index: list
+ ) -> dict:
+     parameters["coef_mean"] = pd.DataFrame(
+         parameters["coef_mean"], columns=var_names, index=mean_coef_index
+     )
+     parameters["coef_zero_inflation"] = pd.DataFrame(
+         parameters["coef_zero_inflation"], columns=var_names,
+         index=zero_inflation_coef_index
+     )
+     return parameters
+
+
+ def zero_inflated_poisson_regression(
+     adata: AnnData, formula: Union[str, dict], chunk_size: int = int(1e4),
+     batch_size: int = 512, **kwargs
+ ) -> dict:
+     formula = data.standardize_formula(
+         formula, allowed_keys=['mean', 'zero_inflation']
+     )
+     loaders = data.multiple_formula_loader(
+         adata, formula, chunk_size=chunk_size, batch_size=batch_size
+     )
+     parameters = zero_inflated_poisson_regression_array(loaders, **kwargs)
+     return format_zero_inflated_poisson_parameters(
+         parameters, list(adata.var_names), loaders["mean"].dataset.x_names,
+         loaders["zero_inflation"].dataset.x_names
+     )
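
The likelihood here is the standard zero-inflated Poisson mixture, log(pi * 1[y = 0] + (1 - pi) * Pois(y | mu)), so the model places pi + (1 - pi) * exp(-mu) mass at zero. A standalone numeric check of that zero mass (independent of the package):

import numpy as np
from scipy.stats import poisson

rng = np.random.default_rng(1)
pi, mu, n = 0.3, 2.0, 200_000

# structural zeros with probability pi, Poisson counts otherwise
structural_zero = rng.uniform(size=n) < pi
y = np.where(structural_zero, 0, rng.poisson(mu, size=n))

# empirical vs. theoretical mass at zero: pi + (1 - pi) * exp(-mu)
print(np.mean(y == 0))
print(pi + (1 - pi) * poisson(mu).pmf(0))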
scdesigner/format/__init__.py
@@ -0,0 +1,4 @@
+ from .format import to_np, format_input_anndata, format_matrix
+ from .print import print_simulator
+
+ __all__ = ["to_np", "format_input_anndata", "format_matrix", "print_simulator"]
scdesigner/format/format.py
@@ -0,0 +1,20 @@
+ from anndata import AnnData
+ from formulaic import model_matrix
+ import pandas as pd
+ import scipy.sparse
+
+
+ def to_np(x):
+     # detach a torch tensor and move it to a numpy array on the CPU
+     return x.detach().cpu().numpy()
+
+
+ def format_input_anndata(adata: AnnData) -> AnnData:
+     # densify sparse count matrices so downstream tensors can be built directly
+     result = adata.copy()
+     if scipy.sparse.issparse(result.X):
+         result.X = result.X.toarray()
+     return result
+
+
+ def format_matrix(obs: pd.DataFrame, formula: str):
+     # build a model matrix from the obs covariates; fall back to obs itself
+     if formula is not None:
+         x = model_matrix(formula, pd.DataFrame(obs))
+     else:
+         x = obs
+     return x
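
format_matrix delegates design-matrix construction to formulaic, so a formula string expands into an intercept plus dummy-coded and numeric columns. A self-contained example of the call it makes internally:

import pandas as pd
from formulaic import model_matrix

obs = pd.DataFrame({"condition": ["a", "b", "a"], "depth": [1.0, 2.0, 3.0]})
x = model_matrix("~ condition + depth", obs)
print(x)
#    Intercept  condition[T.b]  depth
# 0        1.0             0.0    1.0
# 1        1.0             1.0    2.0
# 2        1.0             0.0    3.0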