scdesigner 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of scdesigner might be problematic.

Files changed (66)
  1. scdesigner/__init__.py +0 -0
  2. scdesigner/data/__init__.py +16 -0
  3. scdesigner/data/formula.py +137 -0
  4. scdesigner/data/group.py +123 -0
  5. scdesigner/data/sparse.py +39 -0
  6. scdesigner/diagnose/__init__.py +65 -0
  7. scdesigner/diagnose/aic_bic.py +119 -0
  8. scdesigner/diagnose/plot.py +242 -0
  9. scdesigner/estimators/__init__.py +27 -0
  10. scdesigner/estimators/bernoulli.py +85 -0
  11. scdesigner/estimators/gaussian.py +121 -0
  12. scdesigner/estimators/gaussian_copula_factory.py +152 -0
  13. scdesigner/estimators/glm_factory.py +75 -0
  14. scdesigner/estimators/negbin.py +129 -0
  15. scdesigner/estimators/pnmf.py +160 -0
  16. scdesigner/estimators/poisson.py +100 -0
  17. scdesigner/estimators/zero_inflated_negbin.py +195 -0
  18. scdesigner/estimators/zero_inflated_poisson.py +85 -0
  19. scdesigner/format/__init__.py +4 -0
  20. scdesigner/format/format.py +20 -0
  21. scdesigner/format/print.py +30 -0
  22. scdesigner/minimal/__init__.py +17 -0
  23. scdesigner/minimal/bernoulli.py +61 -0
  24. scdesigner/minimal/composite.py +119 -0
  25. scdesigner/minimal/copula.py +33 -0
  26. scdesigner/minimal/formula.py +23 -0
  27. scdesigner/minimal/gaussian.py +65 -0
  28. scdesigner/minimal/kwargs.py +24 -0
  29. scdesigner/minimal/loader.py +166 -0
  30. scdesigner/minimal/marginal.py +140 -0
  31. scdesigner/minimal/negbin.py +73 -0
  32. scdesigner/minimal/positive_nonnegative_matrix_factorization.py +231 -0
  33. scdesigner/minimal/scd3.py +95 -0
  34. scdesigner/minimal/scd3_instances.py +50 -0
  35. scdesigner/minimal/simulator.py +25 -0
  36. scdesigner/minimal/standard_covariance.py +124 -0
  37. scdesigner/minimal/transform.py +145 -0
  38. scdesigner/minimal/zero_inflated_negbin.py +86 -0
  39. scdesigner/predictors/__init__.py +15 -0
  40. scdesigner/predictors/bernoulli.py +9 -0
  41. scdesigner/predictors/gaussian.py +16 -0
  42. scdesigner/predictors/negbin.py +17 -0
  43. scdesigner/predictors/poisson.py +12 -0
  44. scdesigner/predictors/zero_inflated_negbin.py +18 -0
  45. scdesigner/predictors/zero_inflated_poisson.py +18 -0
  46. scdesigner/samplers/__init__.py +23 -0
  47. scdesigner/samplers/bernoulli.py +27 -0
  48. scdesigner/samplers/gaussian.py +25 -0
  49. scdesigner/samplers/glm_factory.py +41 -0
  50. scdesigner/samplers/negbin.py +25 -0
  51. scdesigner/samplers/poisson.py +25 -0
  52. scdesigner/samplers/zero_inflated_negbin.py +40 -0
  53. scdesigner/samplers/zero_inflated_poisson.py +16 -0
  54. scdesigner/simulators/__init__.py +31 -0
  55. scdesigner/simulators/composite_regressor.py +72 -0
  56. scdesigner/simulators/glm_simulator.py +167 -0
  57. scdesigner/simulators/pnmf_regression.py +61 -0
  58. scdesigner/transform/__init__.py +7 -0
  59. scdesigner/transform/amplify.py +14 -0
  60. scdesigner/transform/mask.py +33 -0
  61. scdesigner/transform/nullify.py +25 -0
  62. scdesigner/transform/split.py +23 -0
  63. scdesigner/transform/substitute.py +14 -0
  64. scdesigner-0.0.1.dist-info/METADATA +23 -0
  65. scdesigner-0.0.1.dist-info/RECORD +66 -0
  66. scdesigner-0.0.1.dist-info/WHEEL +4 -0
scdesigner/minimal/negbin.py
@@ -0,0 +1,73 @@
+ from .formula import standardize_formula
+ from .marginal import GLMPredictor, Marginal
+ from .loader import _to_numpy
+ from typing import Callable, Dict, Optional, Union
+ import torch
+ import numpy as np
+ from scipy.stats import nbinom
+
+ class NegBin(Marginal):
+     """Negative-binomial marginal estimator."""
+     def __init__(self, formula: Union[Dict, str]):
+         formula = standardize_formula(formula, allowed_keys=['mean', 'dispersion'])
+         super().__init__(formula)
+
+     def setup_optimizer(
+         self,
+         optimizer_class: Optional[Callable] = torch.optim.Adam,
+         **optimizer_kwargs,
+     ):
+         if self.loader is None:
+             raise RuntimeError("self.loader is not set (call setup_data first)")
+
+         nll = lambda batch: -self.likelihood(batch).sum()
+         self.predict = GLMPredictor(
+             n_outcomes=self.n_outcomes,
+             feature_dims=self.feature_dims,
+             loss_fn=nll,
+             optimizer_class=optimizer_class,
+             optimizer_kwargs=optimizer_kwargs,
+         )
+
+     def likelihood(self, batch):
+         """Compute the log-likelihood of a (y, x) batch (the optimizer negates it)."""
+         y, x = batch
+         params = self.predict(x)
+         mu = params.get('mean')
+         r = params.get('dispersion')
+         return (
+             torch.lgamma(y + r)
+             - torch.lgamma(r)
+             - torch.lgamma(y + 1.0)
+             + r * torch.log(r)
+             + y * torch.log(mu)
+             - (r + y) * torch.log(r + mu)
+         )
+
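The expression above is the negative-binomial log-pmf under the mean/dispersion parameterization, which corresponds to scipy's nbinom with n = r and p = r / (r + mu). A standalone check of the parameterization (values are illustrative):

import numpy as np
import torch
from scipy.stats import nbinom

mu, r, y = torch.tensor([4.0]), torch.tensor([2.0]), torch.tensor([3.0])
ll = (
    torch.lgamma(y + r) - torch.lgamma(r) - torch.lgamma(y + 1.0)
    + r * torch.log(r) + y * torch.log(mu) - (r + y) * torch.log(r + mu)
)
# scipy's (n, p) parameterization: n = r, p = r / (r + mu)
assert np.isclose(ll.item(), nbinom(n=2.0, p=2.0 / 6.0).logpmf(3))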
+     def invert(self, u: torch.Tensor, x: Dict[str, torch.Tensor]):
+         """Invert pseudo-observations back to counts."""
+         mu, r, u = self._local_params(x, u)
+         p = r / (r + mu)
+         y = nbinom(n=r, p=p).ppf(u)
+         return torch.from_numpy(y).float()
+
+     def uniformize(self, y: torch.Tensor, x: Dict[str, torch.Tensor], epsilon=1e-6):
+         """Return uniformized pseudo-observations for counts y given covariates x."""
+         # CDF values using scipy's (n, p) parameterization
+         mu, r, y = self._local_params(x, y)
+         p = r / (r + mu)
+         u1 = nbinom(n=r, p=p).cdf(y)
+         u2 = np.where(y > 0, nbinom(n=r, p=p).cdf(y - 1), 0.0)
+
+         # randomize within each discrete mass to get Uniform(0, 1) pseudo-observations
+         v = np.random.uniform(size=y.shape)
+         u = np.clip(v * u1 + (1.0 - v) * u2, epsilon, 1.0 - epsilon)
+         return torch.from_numpy(u).float()
+
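uniformize is the distributional transform for discrete data: with V ~ Uniform(0, 1), U = V·F(y) + (1 − V)·F(y − 1) is uniform whenever y ~ F, and the quantile function recovers y from U. A self-contained sketch with a fixed negative binomial (parameters arbitrary):

import numpy as np
from scipy.stats import nbinom

rng = np.random.default_rng(0)
dist = nbinom(n=2.0, p=0.4)
y = dist.rvs(size=10_000, random_state=rng)

# randomize within each discrete mass
v = rng.uniform(size=y.shape)
u = v * dist.cdf(y) + (1.0 - v) * np.where(y > 0, dist.cdf(y - 1), 0.0)

assert abs(u.mean() - 0.5) < 0.02        # approximately Uniform(0, 1)
assert (dist.ppf(u) == y).mean() > 0.99  # the ppf round-trips the counts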
+     def _local_params(self, x, y=None):
+         params = self.predict(x)
+         mu = params.get('mean')
+         r = params.get('dispersion')
+         if y is None:
+             return _to_numpy(mu, r)
+         return _to_numpy(mu, r, y)
scdesigner/minimal/positive_nonnegative_matrix_factorization.py
@@ -0,0 +1,231 @@
+ from .formula import standardize_formula
+ from .loader import _to_numpy
+ from .simulator import Simulator
+ from anndata import AnnData
+ from formulaic import model_matrix
+ from scipy.stats import gamma
+ from typing import Union, Dict
+ import numpy as np
+ import pandas as pd
+ import torch
+
+ ################################################################################
+ ## Functions for estimating PNMF regression
+ ################################################################################
+
+ def pnmf(log_data, nbase=3, **kwargs):
+     """
+     Compute PNMF weights and scores.
+
+     :param log_data: log-transformed numpy array of read data (genes x cells)
+     :param nbase: number of bases (clusters) to extract
+     :return: W (weights, gene x base) and S (scores, base x cell) as numpy arrays
+     """
+     U = left_singular(log_data, nbase)
+     W = pnmf_eucdist(log_data, U, **kwargs)
+     W = W / np.linalg.norm(W, ord=2)
+     S = W.T @ log_data
+     return W, S
+
+
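A quick shape check of pnmf on synthetic data (a sketch; assumes pnmf and its helpers below are in scope, with rows as genes and columns as cells):

import numpy as np

rng = np.random.default_rng(0)
log_data = np.log1p(rng.poisson(2.0, size=(50, 200)).astype(float))

W, S = pnmf(log_data, nbase=3)
assert W.shape == (50, 3)   # gene x base weights
assert S.shape == (3, 200)  # base x cell scores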
+ def gamma_regression_array(
+     x: np.ndarray, y: np.ndarray, lr: float = 0.1, epochs: int = 40
+ ) -> dict:
+     x = torch.tensor(x, dtype=torch.float32)
+     y = torch.tensor(y, dtype=torch.float32)
+
+     # one (log-)shape, location, and (log-)rate coefficient per feature/outcome pair
+     n_features, n_outcomes = x.shape[1], y.shape[1]
+     a = torch.zeros(n_features * n_outcomes, requires_grad=True)
+     loc = torch.zeros(n_features * n_outcomes, requires_grad=True)
+     beta = torch.zeros(n_features * n_outcomes, requires_grad=True)
+     optimizer = torch.optim.Adam([a, loc, beta], lr=lr)
+
+     for i in range(epochs):
+         optimizer.zero_grad()
+         loss = negative_gamma_log_likelihood(a, beta, loc, x, y)
+         loss.backward()
+         optimizer.step()
+
+     a, loc, beta = _to_numpy(a, loc, beta)
+     a = a.reshape(n_features, n_outcomes)
+     loc = loc.reshape(n_features, n_outcomes)
+     beta = beta.reshape(n_features, n_outcomes)
+     return {"a": a, "loc": loc, "beta": beta}
+
+
+ def class_generator(score, n_clusters=3):
+     """
+     Generate a one-hot encoding of KMeans classes for the scores.
+     """
+     from sklearn.cluster import KMeans
+     kmeans = KMeans(n_clusters, random_state=0)
+     kmeans.fit(score.T)
+     labels = kmeans.labels_
+     num_classes = len(np.unique(labels))
+     one_hot = np.eye(num_classes)[labels].astype(int)
+     return one_hot
+
+
+ ###############################################################################
+ ## Helpers for deriving PNMF
+ ###############################################################################
+
+
+ def pnmf_eucdist(X, W_init, maxIter=500, threshold=1e-4, tol=1e-10, verbose=False, **kwargs):
+     # initialization: W starts from the (elementwise absolute) left singular vectors of X
+     W = W_init
+     XX = X @ X.T
+
+     # multiplicative updates
+     for it in range(maxIter):
+         if verbose and (it + 1) % 10 == 0:
+             print("%d iterations used." % (it + 1))
+         W_old = W
+
+         XXW = XX @ W
+         SclFactor = np.dot(W, W.T @ XXW) + np.dot(XXW, W.T @ W)
+
+         # lower-bound the denominator (QuotientLB), then rescale
+         SclFactor = MatFindlb(SclFactor, tol)
+         SclFactor = XXW / SclFactor
+         W = W * SclFactor  # not in-place: W_old aliases W, and the convergence check needs it
+
+         norm_W = np.linalg.norm(W)
+         W /= norm_W
+         W = MatFind(W, tol)
+
+         diffW = np.linalg.norm(W_old - W) / np.linalg.norm(W_old)
+         if diffW < threshold:
+             break
+
+     return W
+
+
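Reading the update off the code, pnmf_eucdist runs the standard multiplicative step for projective NMF under squared Euclidean loss (a reconstruction of the intent; the MatFindlb call only guards the denominator away from zero):

$$\min_{W \ge 0} \; \lVert X - W W^\top X \rVert_F^2, \qquad W \leftarrow W \odot \frac{X X^\top W}{W W^\top X X^\top W \;+\; X X^\top W W^\top W}$$

after which W is renormalized to unit norm and entries below tol are truncated to zero.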
+ # elementwise absolute value of the k leading left singular vectors of X
+ def left_singular(X, k):
+     from scipy.sparse.linalg import svds
+     U, _, _ = svds(X, k=k)
+     return np.abs(U)
+
+
+ def MatFindlb(A, lb):
+     """Clip entries of A from below at lb."""
+     B = np.ones(A.shape) * lb
+     Alb = np.where(A < lb, B, A)
+     return Alb
+
+
+ def MatFind(A, ZeroThres):
+     """Zero out entries of A below ZeroThres."""
+     B = np.zeros(A.shape)
+     Atrunc = np.where(A < ZeroThres, B, A)
+     return Atrunc
+
+
+ ###############################################################################
+ ## Helpers for training PNMF regression
+ ###############################################################################
+
+
+ def shifted_gamma_pdf(x, alpha, beta, loc):
+     """Negative mean log-density of a location-shifted gamma (used as a loss)."""
+     if not torch.is_tensor(x):
+         x = torch.tensor(x)
+     mask = x < loc
+     y_clamped = torch.clamp(x - loc, min=1e-12)
+
+     log_pdf = (
+         alpha * torch.log(beta)
+         - torch.lgamma(alpha)
+         + (alpha - 1) * torch.log(y_clamped)
+         - beta * y_clamped
+     )
+     loss = -torch.mean(log_pdf[~mask])
+     n_invalid = mask.sum()
+     if n_invalid > 0:  # heavily penalize samples below loc, outside the support
+         loss = loss + 1e10 * n_invalid.float()
+     return loss
+
+
+ def negative_gamma_log_likelihood(log_a, log_beta, loc, X, y):
+     n_features = X.shape[1]
+     n_outcomes = y.shape[1]
+
+     # shape and rate are optimized on the log scale to keep them positive
+     a = torch.exp(log_a.reshape(n_features, n_outcomes))
+     beta = torch.exp(log_beta.reshape(n_features, n_outcomes))
+     loc = loc.reshape(n_features, n_outcomes)
+     return shifted_gamma_pdf(y, X @ a, X @ beta, X @ loc)
+
+
+ def format_gamma_parameters(
+     parameters: dict,
+     W_index: list,
+     coef_index: list,
+ ) -> dict:
+     parameters["a"] = pd.DataFrame(parameters["a"], index=coef_index)
+     parameters["loc"] = pd.DataFrame(parameters["loc"], index=coef_index)
+     parameters["beta"] = pd.DataFrame(parameters["beta"], index=coef_index)
+     parameters["W"] = pd.DataFrame(parameters["W"], index=W_index)
+     return parameters
+
+
+ ################################################################################
+ ## Associated PNMF objects
+ ################################################################################
+
+ class PositiveNMF(Simulator):
+     """Positive nonnegative matrix factorization marginal estimator"""
+     def __init__(self, formula: Union[Dict, str], **kwargs):
+         self.formula = standardize_formula(formula, allowed_keys=['mean'])
+         self.parameters = None
+         self.hyperparams = kwargs
+
+     def setup_data(self, adata: AnnData, **kwargs):
+         self.log_data = np.log1p(adata.X).T
+         self.n_outcomes = self.log_data.shape[1]
+         self.template = adata
+         self.x = model_matrix(self.formula["mean"], adata.obs)
+         self.columns = self.x.columns
+         self.x = np.asarray(self.x)
+
+     def fit(self, adata: AnnData, lr: float = 0.1):
+         self.setup_data(adata)
+         W, S = pnmf(self.log_data, **self.hyperparams)
+         parameters = gamma_regression_array(self.x, S.T, lr)
+         parameters["W"] = W
+         self.parameters = format_gamma_parameters(
+             parameters, list(self.template.var_names), list(self.columns)
+         )
+
+     def predict(self, obs=None, **kwargs):
+         """Predict gamma parameters from an obs dataframe."""
+         if obs is None:
+             obs = self.template.obs
+
+         x = model_matrix(self.formula["mean"], obs)
+         a, loc, beta = (
+             x @ np.exp(self.parameters["a"]),
+             x @ self.parameters["loc"],
+             x @ np.exp(self.parameters["beta"]),
+         )
+         return {"a": a, "loc": loc, "beta": beta}
+
+     def sample(self, obs=None):
+         """Generate samples."""
+         if obs is None:
+             obs = self.template.obs
+         W = self.parameters["W"]
+         parameters = self.predict(obs)
+         a, loc, beta = parameters["a"], parameters["loc"], parameters["beta"]
+         sim_score = gamma(a, loc, 1 / beta).rvs()
+         samples = np.exp(W @ sim_score.T).T
+
+         # threshold to integer counts: round up only when the fractional part
+         # is >= 0.9, then shift down by one and clip at zero
+         floor = np.floor(samples)
+         samples = floor + np.where(samples - floor < 0.9, 0, 1) - 1
+         samples = np.where(samples < 0, 0, samples)
+
+         result = AnnData(X=samples, obs=obs)
+         result.var_names = self.template.var_names
+         return result
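A hedged end-to-end sketch of PositiveNMF on synthetic data (assumes standardize_formula promotes a bare string to the 'mean' key, as the marginal classes elsewhere in the package suggest):

import numpy as np
import pandas as pd
from anndata import AnnData

rng = np.random.default_rng(0)
obs = pd.DataFrame({"condition": rng.choice(["a", "b"], size=200)})
adata = AnnData(X=rng.poisson(2.0, size=(200, 50)).astype(float), obs=obs)

model = PositiveNMF("~ condition", nbase=3)
model.fit(adata)
simulated = model.sample()  # AnnData with the same obs and var_names
print(simulated.shape)      # (200, 50)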
scdesigner/minimal/scd3.py
@@ -0,0 +1,95 @@
+ from .copula import Copula
+ from .loader import obs_loader, adata_loader
+ from .marginal import Marginal
+ from .simulator import Simulator
+ from anndata import AnnData
+ from tqdm import tqdm
+ import torch
+ import numpy as np
+
+ class SCD3Simulator(Simulator):
+     """Simulation wrapper combining a marginal model with a copula."""
+
+     def __init__(self, marginal: Marginal, copula: Copula):
+         self.marginal = marginal
+         self.copula = copula
+         self.template = None
+         self.parameters = None
+
+     def fit(self, adata: AnnData, **kwargs):
+         """Fit the marginal model, then the copula on its pseudo-observations."""
+         self.template = adata
+         self.marginal.setup_data(adata, **kwargs)
+         self.marginal.setup_optimizer(**kwargs)
+         self.marginal.fit(**kwargs)
+
+         # copula simulator
+         self.copula.setup_data(adata, self.marginal.formula, **kwargs)
+         self.copula.fit(self.marginal.uniformize, **kwargs)
+         self.parameters = {
+             "marginal": self.marginal.parameters,
+             "copula": self.copula.parameters
+         }
+
+     def predict(self, obs=None, batch_size: int = 1000, **kwargs):
+         """Predict marginal parameters from an obs dataframe."""
+         # prepare an internal data loader for this obs
+         if obs is None:
+             obs = self.template.obs
+         loader = obs_loader(
+             obs,
+             self.marginal.formula,
+             batch_size=batch_size,
+             **kwargs
+         )
+
+         # get predictions across batches
+         local_parameters = []
+         for _, x_dict in loader:
+             local_parameters.append(self.marginal.predict(x_dict))
+
+         # convert to a merged dictionary
+         keys = list(local_parameters[0].keys())
+         return {
+             k: torch.cat([d[k] for d in local_parameters]).detach().cpu().numpy()
+             for k in keys
+         }
+
+     def sample(self, obs=None, batch_size: int = 1000, **kwargs):
+         """Generate samples."""
+         if obs is None:
+             obs = self.template.obs
+         loader = obs_loader(
+             obs,
+             self.copula.formula | self.marginal.formula,
+             batch_size=batch_size,
+             **kwargs
+         )
+
+         # sample copula pseudo-observations, then invert the marginals, per batch
+         samples = []
+         for _, x_dict in loader:
+             u = self.copula.pseudo_obs(x_dict)
+             u = torch.from_numpy(u)
+             samples.append(self.marginal.invert(u, x_dict))
+         samples = torch.cat(samples).detach().cpu().numpy()
+         return AnnData(X=samples, obs=obs)
+
+     def complexity(self, adata: AnnData = None, **kwargs):
+         """Score the fitted copula by AIC/BIC from its log-likelihood."""
+         if adata is None:
+             adata = self.template
+
+         N, ll = 0, 0
+         loader = adata_loader(adata, self.marginal.formula | self.copula.formula, **kwargs)
+         for batch in tqdm(loader, desc="Computing log-likelihood..."):
+             ll += self.copula.likelihood(self.marginal.uniformize, batch).sum()
+             N += len(batch[0])
+
+         return {
+             "aic": -2 * ll + 2 * self.copula.num_params(),
+             "bic": -2 * ll + np.log(N) * self.copula.num_params()
+         }
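complexity aggregates the copula log-likelihood over batches and returns the usual information criteria: with total log-likelihood $\ell$, $k$ = self.copula.num_params() free copula parameters, and $N$ cells,

$$\mathrm{AIC} = -2\ell + 2k, \qquad \mathrm{BIC} = -2\ell + k \log N.$$

Lower values indicate a better fit/complexity trade-off.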
scdesigner/minimal/scd3_instances.py
@@ -0,0 +1,50 @@
+ from .scd3 import SCD3Simulator
+ from .bernoulli import Bernoulli
+ from .negbin import NegBin
+ from .zero_inflated_negbin import ZeroInflatedNegBin
+ from .gaussian import Gaussian
+ from .standard_covariance import StandardCovariance
+ from typing import Optional
+
+
+ class NegBinCopula(SCD3Simulator):
+     def __init__(self,
+                  mean_formula: Optional[str] = None,
+                  dispersion_formula: Optional[str] = None,
+                  copula_formula: Optional[str] = None) -> None:
+         marginal = NegBin({"mean": mean_formula, "dispersion": dispersion_formula})
+         covariance = StandardCovariance(copula_formula)
+         super().__init__(marginal, covariance)
+
+
+ class ZeroInflatedNegBinCopula(SCD3Simulator):
+     def __init__(self,
+                  mean_formula: Optional[str] = None,
+                  dispersion_formula: Optional[str] = None,
+                  zero_inflation_formula: Optional[str] = None,
+                  copula_formula: Optional[str] = None) -> None:
+         marginal = ZeroInflatedNegBin({
+             "mean": mean_formula,
+             "dispersion": dispersion_formula,
+             "zero_inflation": zero_inflation_formula
+         })
+         covariance = StandardCovariance(copula_formula)
+         super().__init__(marginal, covariance)
+
+
+ class BernoulliCopula(SCD3Simulator):
+     def __init__(self,
+                  mean_formula: Optional[str] = None,
+                  copula_formula: Optional[str] = None) -> None:
+         marginal = Bernoulli({"mean": mean_formula})
+         covariance = StandardCovariance(copula_formula)
+         super().__init__(marginal, covariance)
+
+
+ class GaussianCopula(SCD3Simulator):
+     def __init__(self,
+                  mean_formula: Optional[str] = None,
+                  sdev_formula: Optional[str] = None,
+                  copula_formula: Optional[str] = None) -> None:
+         marginal = Gaussian({"mean": mean_formula, "sdev": sdev_formula})
+         covariance = StandardCovariance(copula_formula)
+         super().__init__(marginal, covariance)
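A minimal usage sketch of these preconfigured simulators (synthetic data; assumes the loader/formula plumbing elsewhere in the package accepts formulaic-style strings as in scd3.py above):

import numpy as np
import pandas as pd
from anndata import AnnData

rng = np.random.default_rng(0)
obs = pd.DataFrame({"condition": rng.choice(["a", "b"], size=300)})
adata = AnnData(X=rng.poisson(3.0, size=(300, 20)).astype(float), obs=obs)

simulator = NegBinCopula(
    mean_formula="~ condition",
    dispersion_formula="~ 1",
    copula_formula="~ condition",
)
simulator.fit(adata)
synthetic = simulator.sample()  # AnnData with the same obs as the template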
scdesigner/minimal/simulator.py
@@ -0,0 +1,25 @@
+ from anndata import AnnData
+ from typing import Dict, Optional
+ from pandas import DataFrame
+ from abc import abstractmethod
+
+ class Simulator:
+     """Simulation abstract class"""
+
+     def __init__(self):
+         self.parameters = None
+
+     @abstractmethod
+     def fit(self, anndata: AnnData, **kwargs) -> None:
+         """Fit the simulator"""
+         self.template = anndata
+
+     @abstractmethod
+     def predict(self, obs: Optional[DataFrame] = None, **kwargs) -> Dict:
+         """Predict from an obs dataframe"""
+         pass
+
+     @abstractmethod
+     def sample(self, obs: Optional[DataFrame] = None, **kwargs) -> AnnData:
+         """Generate samples."""
+         pass
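Subclasses fill in the three abstract methods; a toy, hypothetical subclass that just bootstrap-resamples the fitted template illustrates the contract:

import numpy as np
from anndata import AnnData

class ResampleSimulator(Simulator):
    """Hypothetical example: sample() returns bootstrap-resampled template rows."""

    def fit(self, anndata: AnnData, **kwargs) -> None:
        self.template = anndata

    def predict(self, obs=None, **kwargs) -> dict:
        return {}  # no parametric form for plain resampling

    def sample(self, obs=None, **kwargs) -> AnnData:
        if obs is None:
            obs = self.template.obs
        ix = np.random.choice(self.template.n_obs, size=len(obs))
        return AnnData(X=self.template.X[ix], obs=obs)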
scdesigner/minimal/standard_covariance.py
@@ -0,0 +1,124 @@
+ from .copula import Copula
+ from .formula import standardize_formula
+ from .kwargs import DEFAULT_ALLOWED_KWARGS, _filter_kwargs
+ from anndata import AnnData
+ from scipy.stats import norm, multivariate_normal
+ from tqdm import tqdm
+ from typing import Dict, Union, Callable, Tuple
+ import numpy as np
+ import pandas as pd
+ import torch
+
+
+ class StandardCovariance(Copula):
+     def __init__(self, formula: str = "~ 1"):
+         formula = standardize_formula(formula, allowed_keys=['group'])
+         super().__init__(formula)
+         self.groups = None
+
+     def setup_data(self, adata: AnnData, marginal_formula: Dict[str, str], **kwargs):
+         data_kwargs = _filter_kwargs(kwargs, DEFAULT_ALLOWED_KWARGS['data'])
+         super().setup_data(adata, marginal_formula, **data_kwargs)
+         _, obs_batch = next(iter(self.loader))
+         obs_batch_group = obs_batch.get("group")
+
+         # fill in group indexing variables
+         self.groups = self.loader.dataset.predictor_names["group"]
+         self.n_groups = len(self.groups)
+         self.group_col = {g: i for i, g in enumerate(self.groups)}
+
+         # check that obs_batch is a binary grouping matrix
+         unique_vals = torch.unique(obs_batch_group)
+         if not torch.all((unique_vals == 0) | (unique_vals == 1)).item():
+             raise ValueError("Only categorical groups are currently supported in copula covariance estimation.")
+
+     def fit(self, uniformizer: Callable, **kwargs):
+         # stream per-group first and second moments of the Gaussian scores,
+         # starting the second-moment accumulators from the identity
+         sums = {g: np.zeros(self.n_outcomes) for g in self.groups}
+         second_moments = {g: np.eye(self.n_outcomes) for g in self.groups}
+         Ng = {g: 0 for g in self.groups}
+
+         for y, x_dict in tqdm(self.loader, desc="Estimating copula covariance"):
+             memberships = x_dict.get("group").numpy()
+             u = uniformizer(y, x_dict)
+
+             for g in self.groups:
+                 ix = np.where(memberships[:, self.group_col[g]] == 1)
+                 z = norm().ppf(u[ix])
+                 second_moments[g] += z.T @ z
+                 sums[g] += z.sum(axis=0)
+                 Ng[g] += len(ix[0])
+
+         covariances = {}
+         for g in self.groups:
+             mean = sums[g] / Ng[g]
+             covariances[g] = second_moments[g] / Ng[g] - np.outer(mean, mean)
+
+         # with a single group, store one covariance rather than a dict
+         if len(self.groups) == 1:
+             covariances = list(covariances.values())[0]
+         self.parameters = self.format_parameters(covariances)
+
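fit computes the textbook Gaussian-copula covariance estimate: pseudo-observations are mapped to normal scores $z_i = \Phi^{-1}(u_i)$, and for each group $g$ with $N_g$ cells,

$$\hat{\Sigma}_g = \frac{1}{N_g} \sum_{i \in g} z_i z_i^\top \;-\; \bar{z}_g \bar{z}_g^\top.$$

(Starting the second-moment accumulator at the identity effectively adds a vanishing $I / N_g$ ridge.)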
+     def format_parameters(self, covariances: Union[Dict, np.ndarray]):
+         var_names = self.adata.var_names
+         def to_df(mat):
+             return pd.DataFrame(mat, index=var_names, columns=var_names)
+
+         if isinstance(covariances, dict):
+             return {k: to_df(v) for k, v in covariances.items()}
+
+         if isinstance(covariances, (np.ndarray, list, tuple)):
+             covariances = to_df(covariances)
+         return covariances
+
+     def pseudo_obs(self, x_dict: Dict):
+         # convert one-hot membership columns to a map
+         # {"group1": [indices of group 1], "group2": [indices of group 2]}
+         memberships = x_dict.get("group").numpy()
+         group_ix = {g: np.where(memberships[:, self.group_col[g]] == 1)[0] for g in self.groups}
+
+         # initialize the result
+         u = np.zeros((len(memberships), self.n_outcomes))
+         parameters = self.parameters
+         if not isinstance(parameters, dict):
+             parameters = {self.groups[0]: parameters}
+
+         # loop over groups and sample each part in turn
+         for group, sigma in parameters.items():
+             z = np.random.multivariate_normal(
+                 mean=np.zeros(self.n_outcomes),
+                 cov=sigma,
+                 size=len(group_ix[group])
+             )
+             normal_distn = norm(0, np.diag(sigma) ** 0.5)
+             u[group_ix[group]] = normal_distn.cdf(z)
+         return u
+
+     def likelihood(self, uniformizer: Callable, batch: Tuple[torch.Tensor, Dict[str, torch.Tensor]]):
+         # uniformize the observations
+         y, x_dict = batch
+         u = uniformizer(y, x_dict)
+         z = norm().ppf(u)
+
+         # same group manipulation as for pseudo_obs
+         parameters = self.parameters
+         if not isinstance(parameters, dict):
+             parameters = {self.groups[0]: parameters}
+
+         memberships = x_dict.get("group").numpy()
+         group_ix = {g: np.where(memberships[:, self.group_col[g]] == 1)[0] for g in self.groups}
+         ll = np.zeros(len(z))
+         for group, sigma in parameters.items():
+             ix = group_ix[group]
+             if len(ix) > 0:
+                 copula_ll = multivariate_normal.logpdf(z[ix], np.zeros(sigma.shape[0]), sigma)
+                 ll[ix] = copula_ll - norm.logpdf(z[ix]).sum(axis=1)
+         return ll
+
+     def num_params(self, **kwargs):
+         # free correlation parameters: nonzero off-diagonal entries, halved by symmetry
+         S = self.parameters
+         if not isinstance(S, dict):
+             S = {self.groups[0]: S}
+         per_group = [(np.sum(S[g].values != 0) - S[g].shape[0]) / 2 for g in self.groups]
+         return sum(per_group)
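The counting rule treats each symmetric nonzero pair as one free parameter; a small illustration with made-up values:

import numpy as np
import pandas as pd

S = pd.DataFrame([
    [1.0, 0.2, 0.0],
    [0.2, 1.0, 0.5],
    [0.0, 0.5, 1.0],
])
# 7 nonzero entries, minus the 3 diagonal ones, halved: 2 free parameters
assert (np.sum(S.values != 0) - S.shape[0]) / 2 == 2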