scdesigner 0.0.5__py3-none-any.whl → 0.0.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (84)
  1. scdesigner/base/__init__.py +8 -0
  2. scdesigner/base/copula.py +416 -0
  3. scdesigner/base/marginal.py +391 -0
  4. scdesigner/base/simulator.py +59 -0
  5. scdesigner/copulas/__init__.py +8 -0
  6. scdesigner/copulas/standard_copula.py +645 -0
  7. scdesigner/datasets/__init__.py +5 -0
  8. scdesigner/datasets/pancreas.py +39 -0
  9. scdesigner/distributions/__init__.py +19 -0
  10. scdesigner/{minimal → distributions}/bernoulli.py +42 -14
  11. scdesigner/distributions/gaussian.py +114 -0
  12. scdesigner/distributions/negbin.py +121 -0
  13. scdesigner/distributions/negbin_irls.py +72 -0
  14. scdesigner/distributions/negbin_irls_funs.py +456 -0
  15. scdesigner/distributions/poisson.py +88 -0
  16. scdesigner/{minimal → distributions}/zero_inflated_negbin.py +39 -10
  17. scdesigner/distributions/zero_inflated_poisson.py +103 -0
  18. scdesigner/simulators/__init__.py +24 -28
  19. scdesigner/simulators/composite.py +239 -0
  20. scdesigner/simulators/positive_nonnegative_matrix_factorization.py +477 -0
  21. scdesigner/simulators/scd3.py +486 -0
  22. scdesigner/transform/__init__.py +8 -6
  23. scdesigner/{minimal → transform}/transform.py +1 -1
  24. scdesigner/{minimal → utils}/kwargs.py +4 -1
  25. {scdesigner-0.0.5.dist-info → scdesigner-0.0.10.dist-info}/METADATA +1 -1
  26. scdesigner-0.0.10.dist-info/RECORD +28 -0
  27. {scdesigner-0.0.5.dist-info → scdesigner-0.0.10.dist-info}/WHEEL +1 -1
  28. scdesigner/data/__init__.py +0 -16
  29. scdesigner/data/formula.py +0 -137
  30. scdesigner/data/group.py +0 -123
  31. scdesigner/data/sparse.py +0 -39
  32. scdesigner/diagnose/__init__.py +0 -65
  33. scdesigner/diagnose/aic_bic.py +0 -119
  34. scdesigner/diagnose/plot.py +0 -242
  35. scdesigner/estimators/__init__.py +0 -32
  36. scdesigner/estimators/bernoulli.py +0 -85
  37. scdesigner/estimators/gaussian.py +0 -121
  38. scdesigner/estimators/gaussian_copula_factory.py +0 -367
  39. scdesigner/estimators/glm_factory.py +0 -75
  40. scdesigner/estimators/negbin.py +0 -153
  41. scdesigner/estimators/pnmf.py +0 -160
  42. scdesigner/estimators/poisson.py +0 -124
  43. scdesigner/estimators/zero_inflated_negbin.py +0 -195
  44. scdesigner/estimators/zero_inflated_poisson.py +0 -85
  45. scdesigner/format/__init__.py +0 -4
  46. scdesigner/format/format.py +0 -20
  47. scdesigner/format/print.py +0 -30
  48. scdesigner/minimal/__init__.py +0 -17
  49. scdesigner/minimal/composite.py +0 -119
  50. scdesigner/minimal/copula.py +0 -205
  51. scdesigner/minimal/formula.py +0 -23
  52. scdesigner/minimal/gaussian.py +0 -65
  53. scdesigner/minimal/loader.py +0 -211
  54. scdesigner/minimal/marginal.py +0 -154
  55. scdesigner/minimal/negbin.py +0 -73
  56. scdesigner/minimal/positive_nonnegative_matrix_factorization.py +0 -231
  57. scdesigner/minimal/scd3.py +0 -96
  58. scdesigner/minimal/scd3_instances.py +0 -50
  59. scdesigner/minimal/simulator.py +0 -25
  60. scdesigner/minimal/standard_copula.py +0 -383
  61. scdesigner/predictors/__init__.py +0 -15
  62. scdesigner/predictors/bernoulli.py +0 -9
  63. scdesigner/predictors/gaussian.py +0 -16
  64. scdesigner/predictors/negbin.py +0 -17
  65. scdesigner/predictors/poisson.py +0 -12
  66. scdesigner/predictors/zero_inflated_negbin.py +0 -18
  67. scdesigner/predictors/zero_inflated_poisson.py +0 -18
  68. scdesigner/samplers/__init__.py +0 -23
  69. scdesigner/samplers/bernoulli.py +0 -27
  70. scdesigner/samplers/gaussian.py +0 -25
  71. scdesigner/samplers/glm_factory.py +0 -103
  72. scdesigner/samplers/negbin.py +0 -25
  73. scdesigner/samplers/poisson.py +0 -25
  74. scdesigner/samplers/zero_inflated_negbin.py +0 -40
  75. scdesigner/samplers/zero_inflated_poisson.py +0 -16
  76. scdesigner/simulators/composite_regressor.py +0 -72
  77. scdesigner/simulators/glm_simulator.py +0 -167
  78. scdesigner/simulators/pnmf_regression.py +0 -61
  79. scdesigner/transform/amplify.py +0 -14
  80. scdesigner/transform/mask.py +0 -33
  81. scdesigner/transform/nullify.py +0 -25
  82. scdesigner/transform/split.py +0 -23
  83. scdesigner/transform/substitute.py +0 -14
  84. scdesigner-0.0.5.dist-info/RECORD +0 -66
@@ -1,367 +0,0 @@
1
- from ..data import stack_collate, multiple_formula_group_loader
2
- from .. import data
3
- from anndata import AnnData
4
- from collections.abc import Callable
5
- from typing import Union
6
- from scipy.stats import norm
7
- from torch.utils.data import DataLoader
8
- import numpy as np
9
- import pandas as pd
10
-
11
- ###############################################################################
12
- ## General copula factory functions
13
- ###############################################################################
14
-
15
-
def gaussian_copula_array_factory(marginal_model: Callable, uniformizer: Callable):
    """
    Build a two-stage estimator: marginal fit, then Gaussian-copula covariance.

    Parameters
    ----------
    marginal_model : Callable
        Called as ``marginal_model(stripped_loaders, lr=..., epochs=..., **kwargs)``;
        must return a mutable dict of fitted parameters.
    uniformizer : Callable
        Forwarded to ``copula_covariance`` to map observations into (0, 1).

    Returns
    -------
    Callable
        ``copula_fun(loaders, lr=0.1, epochs=40, **kwargs)`` returning the marginal
        parameters with an added ``"covariance"`` entry.
    """
    def copula_fun(loaders: dict[str, DataLoader], lr: float = 0.1, epochs: int = 40, **kwargs):
        # The marginal fit ignores groupings, so every loader is re-wrapped without
        # group labels; loaders whose dataset class name contains "Stack" use pop=True.
        stripped = {
            key: strip_dataloader(dl, pop="Stack" in type(dl.dataset).__name__)
            for key, dl in loaders.items()
        }
        parameters = marginal_model(stripped, lr=lr, epochs=epochs, **kwargs)

        # The covariance step uses the original, group-aware loaders.
        parameters["covariance"] = copula_covariance(parameters, loaders, uniformizer)
        return parameters

    return copula_fun
32
-
33
-
def fast_gaussian_copula_array_factory(marginal_model: Callable, uniformizer: Callable, top_k: int):
    """
    Build a copula estimator that uses the top-k block-plus-diagonal covariance.

    Identical to ``gaussian_copula_array_factory`` except that the covariance step
    calls ``fast_copula_covariance(..., top_k)``.

    Parameters
    ----------
    marginal_model : Callable
        Called as ``marginal_model(stripped_loaders, lr=..., epochs=..., **kwargs)``;
        must return a mutable dict of fitted parameters.
    uniformizer : Callable
        Forwarded to ``fast_copula_covariance``.
    top_k : int
        Number of genes that receive a full covariance block.

    Returns
    -------
    Callable
        ``copula_fun(loaders, lr=0.1, epochs=40, **kwargs)``.

    Raises
    ------
    ValueError
        If ``top_k`` is not positive. This is checked here so the error surfaces
        before the (potentially long) marginal fit rather than after it.
    """
    if top_k <= 0:
        raise ValueError("top_k must be a positive integer")

    def copula_fun(loaders: dict[str, DataLoader], lr: float = 0.1, epochs: int = 40, **kwargs):
        # The marginal fit ignores groupings: re-wrap each loader without group labels.
        formula_loaders = {
            key: strip_dataloader(dl, pop="Stack" in type(dl.dataset).__name__)
            for key, dl in loaders.items()
        }
        parameters = marginal_model(formula_loaders, lr=lr, epochs=epochs, **kwargs)

        # Covariance via the fast top-k approximation, using the group-aware loaders.
        parameters["covariance"] = fast_copula_covariance(parameters, loaders, uniformizer, top_k)
        return parameters

    return copula_fun
54
-
55
-
def gaussian_copula_factory(copula_array_fun: Callable,
                            parameter_formatter: Callable,
                            param_name: list = None):
    """
    Wrap a dataloader-level copula estimator into one that accepts an AnnData.

    Parameters
    ----------
    copula_array_fun : Callable
        Called as ``copula_array_fun(dataloaders, **kwargs)``.
    parameter_formatter : Callable
        Called as ``parameter_formatter(parameters, var_names, dataloaders)``.
    param_name : list, optional
        When given, the formula is standardized against these parameter names.

    Returns
    -------
    Callable
        ``copula_fun(adata, formula="~ 1", grouping_var=None, chunk_size=10000,
        batch_size=512, **kwargs) -> dict``.
    """
    def copula_fun(
        adata: AnnData,
        formula: Union[str, dict] = "~ 1",
        grouping_var: str = None,
        chunk_size: int = int(1e4),
        batch_size: int = 512,
        **kwargs
    ) -> dict:
        if param_name is not None:
            formula = data.standardize_formula(formula, param_name)

        # A dict of dataloaders, one per formula key, grouped by ``grouping_var``.
        dls = multiple_formula_group_loader(
            adata, formula, grouping_var, chunk_size=chunk_size, batch_size=batch_size
        )
        raw = copula_array_fun(dls, **kwargs)

        # The formatter receives all dataloaders so it can read what it needs
        # (e.g. coefficient names).
        fitted = parameter_formatter(raw, adata.var_names, dls)
        fitted["covariance"] = format_copula_parameters(fitted, adata.var_names)
        return fitted

    return copula_fun
88
-
89
-
90
-
91
-
def copula_covariance(parameters: dict, loaders: dict[str, DataLoader], uniformizer: Callable):
    """
    Estimate the Gaussian-copula covariance of each group.

    Each batch is mapped into (0, 1) by ``uniformizer`` and then to normal scores
    with ``norm.ppf``. Running sums and second moments of those scores are
    accumulated per group.

    Parameters
    ----------
    parameters : dict
        Fitted marginal parameters, passed unchanged to ``uniformizer``.
    loaders : dict[str, DataLoader]
        One loader per formula key, iterated in lockstep. Batches are indexed as
        ``(x, y, memberships)``; ``y`` and ``memberships`` are read from the first
        loader. The first loader's ``dataset.groups`` lists the group names.
    uniformizer : Callable
        ``uniformizer(parameters, x_batch_dict, y_batch)``, returning an array with
        one row per observation.

    Returns
    -------
    np.ndarray or dict
        A ``(D, D)`` array when there is a single group, otherwise a dict mapping
        each group name to its ``(D, D)`` array.
    """
    key_names = list(loaders)
    loader_list = list(loaders.values())
    lead = loader_list[0]
    n_genes = next(iter(lead))[1].shape[1]  # number of columns of y
    group_names = lead.dataset.groups

    # NOTE(review): the second moments start from the identity, whereas
    # fast_copula_covariance starts from zeros. This adds I / N_g to every
    # estimate; confirm this small ridge is intended.
    totals = {g: np.zeros(n_genes) for g in group_names}
    cross = {g: np.eye(n_genes) for g in group_names}
    counts = dict.fromkeys(group_names, 0)

    for batches in zip(*loader_list):
        x_batch_dict = {name: batch[0].cpu().numpy() for name, batch in zip(key_names, batches)}
        y_batch = batches[0][1].cpu().numpy()
        labels = np.array(batches[0][2])  # group labels; should match across keys

        u = uniformizer(parameters, x_batch_dict, y_batch)
        for g in group_names:
            rows = np.where(labels == g)[0]
            z = norm.ppf(u[rows])
            cross[g] += z.T @ z
            totals[g] += z.sum(axis=0)
            counts[g] += len(rows)

    covariances = {}
    for g in group_names:
        centre = totals[g] / counts[g]
        covariances[g] = cross[g] / counts[g] - np.outer(centre, centre)

    return next(iter(covariances.values())) if len(group_names) == 1 else covariances
127
-
128
-
def fast_copula_covariance(parameters: dict, loaders: dict[str, DataLoader], uniformizer: Callable, top_k: int):
    """
    Approximate the copula covariance with a top-k full block plus a diagonal.

    The ``top_k`` genes with the largest total observed expression get a full
    covariance block. All remaining genes only get their marginal variances.

    Parameters
    ----------
    parameters : dict
        Fitted marginal parameters, passed unchanged to ``uniformizer``.
    loaders : dict[str, DataLoader]
        One loader per formula key, iterated in lockstep. Batches are indexed as
        ``(x, y, memberships)``; ``y`` and ``memberships`` are read from the first
        loader. The first loader's ``dataset.groups`` lists the group names.
    uniformizer : Callable
        ``uniformizer(parameters, x_batch_dict, y_batch)``, returning an array with
        one row per observation.
    top_k : int
        Number of genes modelled with a full covariance block.

    Returns
    -------
    FastCovarianceStructure or dict
        With a single group, a ``FastCovarianceStructure`` holding:

        * ``top_k_cov`` (top_k, top_k)
        * ``remaining_var`` (D - top_k,)
        * ``top_k_indices`` / ``remaining_indices``
        * ``gene_total_expression``

        With several groups, a dict mapping group name to structure; groups with
        no observations are omitted. If ``top_k >= D``, this falls back to
        ``copula_covariance`` and returns its result instead.

    Raises
    ------
    ValueError
        If ``top_k`` is not positive, or if the only group has no observations.
    """
    keys = list(loaders.keys())
    loaders_list = list(loaders.values())
    first_loader = loaders_list[0]
    D = next(iter(first_loader))[1].shape[1]  # number of columns of y
    groups = first_loader.dataset.groups

    if top_k <= 0:
        raise ValueError("top_k must be a positive integer")
    if top_k >= D:
        # Nothing to approximate: use the exact covariance instead.
        return copula_covariance(parameters, loaders, uniformizer)

    # Pass 1: total expression per gene. y is the same for every key, so only the
    # first loader needs to be read (the other loaders' x batches are not needed).
    gene_total_expression = np.zeros(D)
    for batch in first_loader:
        gene_total_expression += batch[1].cpu().numpy().sum(axis=0)

    # Sort once; the last top_k entries are the most expressed genes.
    order = np.argsort(gene_total_expression)
    top_k_indices = order[-top_k:]
    remaining_indices = order[:-top_k]
    n_remaining = len(remaining_indices)

    # Pass 2: per-group sums and second moments of the normal scores.
    sums_top_k = {g: np.zeros(top_k) for g in groups}
    second_moments_top_k = {g: np.zeros((top_k, top_k)) for g in groups}
    sums_remaining = {g: np.zeros(n_remaining) for g in groups}
    second_moments_remaining = {g: np.zeros(n_remaining) for g in groups}
    Ng = {g: 0 for g in groups}

    for batches in zip(*loaders_list):
        x_batch_dict = {key: batch[0].cpu().numpy() for key, batch in zip(keys, batches)}
        y_batch = batches[0][1].cpu().numpy()
        memberships = np.asarray(batches[0][2])  # should be identical for all keys

        u = uniformizer(parameters, x_batch_dict, y_batch)
        for g in groups:
            ix = np.where(memberships == g)[0]
            if ix.size == 0:
                continue

            z = norm.ppf(u[ix])

            # Full cross-products for the top-k genes.
            z_top_k = z[:, top_k_indices]
            second_moments_top_k[g] += z_top_k.T @ z_top_k
            sums_top_k[g] += z_top_k.sum(axis=0)

            # Only squared terms (diagonal) for the remaining genes.
            z_remaining = z[:, remaining_indices]
            second_moments_remaining[g] += (z_remaining ** 2).sum(axis=0)
            sums_remaining[g] += z_remaining.sum(axis=0)

            Ng[g] += ix.size

    result = {}
    for g in groups:
        n = Ng[g]
        if n == 0:
            continue

        mean_top_k = sums_top_k[g] / n
        cov_top_k = second_moments_top_k[g] / n - np.outer(mean_top_k, mean_top_k)

        mean_remaining = sums_remaining[g] / n
        var_remaining = second_moments_remaining[g] / n - mean_remaining ** 2

        result[g] = FastCovarianceStructure(
            top_k_cov=cov_top_k,
            remaining_var=var_remaining,
            top_k_indices=top_k_indices,
            remaining_indices=remaining_indices,
            gene_total_expression=gene_total_expression,
        )

    if len(groups) == 1:
        if not result:
            raise ValueError("no observations found for the single copula group")
        return next(iter(result.values()))
    return result
250
-
251
-
class FastCovarianceStructure:
    """
    Block-plus-diagonal covariance produced by ``fast_copula_covariance``.

    Attributes
    ----------
    top_k_cov : np.ndarray
        Full covariance among the top-k genes, shape (top_k, top_k).
    remaining_var : np.ndarray
        Variances of the remaining genes, shape (remaining_genes,).
    top_k_indices : np.ndarray
        Positions of the top-k genes in the original gene ordering.
    remaining_indices : np.ndarray
        Positions of the remaining genes in the original gene ordering.
    gene_total_expression : np.ndarray
        Per-gene totals that were used to choose the top-k genes.
    top_k : int
        ``len(top_k_indices)``.
    total_genes : int
        ``top_k + len(remaining_indices)``.
    """

    def __init__(self, top_k_cov, remaining_var, top_k_indices, remaining_indices, gene_total_expression):
        # Inputs are stored as given (no copies, no validation).
        self.top_k_cov = top_k_cov
        self.remaining_var = remaining_var
        self.top_k_indices = top_k_indices
        self.remaining_indices = remaining_indices
        self.gene_total_expression = gene_total_expression
        self.top_k = len(top_k_indices)
        self.total_genes = self.top_k + len(remaining_indices)

    def __repr__(self):
        n_rest = len(self.remaining_indices)
        return f"FastCovarianceStructure(top_k={self.top_k}, remaining_genes={n_rest}, total_genes={self.total_genes})"

    def to_full_matrix(self):
        """
        Expand into a dense (total_genes, total_genes) covariance matrix.

        Entries between a top-k gene and a remaining gene, and off-diagonal entries
        among remaining genes, are zero.
        """
        size = self.total_genes
        dense = np.zeros((size, size))

        top = self.top_k_indices
        dense[np.ix_(top, top)] = self.top_k_cov

        rest = self.remaining_indices
        dense[rest, rest] = self.remaining_var
        return dense
302
-
303
-
304
-
305
- ###############################################################################
306
- ## Helpers to prepare and postprocess copula parameters
307
- ###############################################################################
308
-
309
-
def group_indices(grouping_var: str, obs: pd.DataFrame) -> dict:
    """
    Map every group of ``grouping_var`` to the row positions of ``obs`` in it.

    Parameters
    ----------
    grouping_var : str or None
        Column of ``obs`` that defines the groups. When ``None``, all rows go into
        a single ``"shared_group"``, stored in an ``"_copula_group"`` column. That
        column is added to ``obs`` in place only if it does not already exist.
    obs : pd.DataFrame
        Observation table. The grouping column may be categorical or not.

    Returns
    -------
    dict
        ``{group: np.ndarray of row positions}``. Keys are in category order for
        categorical columns and in sorted order otherwise.
    """
    if grouping_var is None:
        grouping_var = "_copula_group"
        # Test for the same column name that is created here. Checking a different
        # name would rebuild the column on every call, and would skip creating it
        # (then KeyError below) whenever a "copula_group" column happened to exist.
        if grouping_var not in obs.columns:
            obs[grouping_var] = pd.Categorical(["shared_group"] * len(obs))

    column = obs[grouping_var]
    # pd.Categorical keeps an existing categorical's category order and also
    # accepts plain (non-categorical) columns.
    categories = pd.Categorical(column).categories
    values = column.values
    return {group: np.where(values == group)[0] for group in categories}
323
-
324
-
def clip(u: np.array, min: float = 1e-5, max: float = 1 - 1e-5) -> np.array:
    """
    Clamp ``u`` into ``[min, max]`` in place and return the same array.

    Values below ``min`` are raised to ``min`` first; values above ``max`` are then
    lowered to ``max``.
    """
    below = u < min
    u[below] = min
    above = u > max
    u[above] = max
    return u
329
-
330
-
def format_copula_parameters(parameters: dict, var_names: list):
    """
    Return ``parameters["covariance"]`` labelled with ``var_names``.

    - A ``FastCovarianceStructure``, or a dict containing at least one, is
      returned unchanged.
    - Any other non-dict value becomes a ``pd.DataFrame`` indexed and
      columned by ``var_names``.
    - For a plain ``dict`` of per-group matrices, each non-structure value is
      replaced in place by such a DataFrame, and the same dict is returned.
    """
    covariance = parameters["covariance"]

    # Fast structures carry their own index bookkeeping; keep them as they are.
    if isinstance(covariance, FastCovarianceStructure):
        return covariance
    if isinstance(covariance, dict) and any(
        isinstance(value, FastCovarianceStructure) for value in covariance.values()
    ):
        return covariance

    labels = list(var_names)
    if type(covariance) is not dict:
        return pd.DataFrame(covariance, columns=labels, index=labels)

    for group, matrix in covariance.items():
        if not isinstance(matrix, FastCovarianceStructure):
            covariance[group] = pd.DataFrame(matrix, columns=labels, index=labels)
    return covariance
359
-
360
-
def strip_dataloader(dataloader, pop=False):
    """
    Re-wrap ``dataloader`` with a collate function that omits group labels.

    The dataset and batch sampler are reused as-is; only the collate function
    changes, to ``stack_collate(pop=pop, groups=False)``.
    """
    collate = stack_collate(pop=pop, groups=False)
    return DataLoader(
        dataset=dataloader.dataset,
        batch_sampler=dataloader.batch_sampler,
        collate_fn=collate,
    )
367
-
@@ -1,75 +0,0 @@
1
- from tqdm import tqdm
2
- from torch.utils.data import DataLoader
3
- import torch
4
-
5
-
def glm_regression_factory(likelihood, initializer, postprocessor) -> dict:
    """
    Build a single-formula GLM estimator trained with Adam.

    Parameters
    ----------
    likelihood : Callable
        ``likelihood(params, x_batch, y_batch)``, returning a scalar loss tensor.
    initializer : Callable
        ``initializer(x, y, device)``, called on the first batch; returns the
        parameter tensor to optimise.
    postprocessor : Callable
        ``postprocessor(params, n_features, n_outcomes)``; its result is returned.

    Returns
    -------
    Callable
        ``estimator(dataloader, lr=0.1, epochs=40)``. (Despite the ``-> dict``
        annotation, the factory returns this function.)
    """
    def estimator(
        dataloader: DataLoader,
        lr: float = 0.1,
        epochs: int = 40,
    ):
        device = check_device()
        x, y = next(iter(dataloader))
        params = initializer(x, y, device)
        optimizer = torch.optim.Adam([params], lr=lr)

        for epoch in range(epochs):
            for x_batch, y_batch in (pbar := tqdm(dataloader, desc=f"Epoch {epoch + 1}/{epochs}", leave=False)):
                # Put each batch on the training device (as the multi-formula estimator
                # does). Otherwise, on CUDA/MPS the loss mixes device and CPU tensors
                # and fails.
                x_batch = x_batch.to(device)
                y_batch = y_batch.to(device)
                optimizer.zero_grad()
                loss = likelihood(params, x_batch, y_batch)
                loss.backward()
                optimizer.step()
                pbar.set_postfix_str(f"loss: {loss.item()}")

        return postprocessor(params, x.shape[1], y.shape[1])

    return estimator
29
-
def multiple_formula_regression_factory(likelihood, initializer, postprocessor) -> dict:
    """
    Build a GLM estimator whose design matrices come from several formulas.

    Parameters
    ----------
    likelihood : Callable
        ``likelihood(params, x_batch_dict, y_batch)``, returning a scalar loss tensor.
    initializer : Callable
        ``initializer(x_dict, y, device)``, called on the first batch of each loader.
    postprocessor : Callable
        ``postprocessor(params, x_dict, y)``, where ``x_dict``/``y`` are those first
        batches; its result is returned.

    Returns
    -------
    Callable
        ``estimator(dataloaders, lr=0.1, epochs=40)``, where ``dataloaders`` maps
        a formula key to a loader yielding ``(x, y)`` pairs.

    Raises
    ------
    ValueError
        (from the estimator) if the first ``y`` batch differs between loaders.
    """
    def estimator(
        dataloaders: dict[str, DataLoader],
        lr: float = 0.1,
        epochs: int = 40,
    ):
        device = check_device()
        keys = list(dataloaders.keys())
        loaders = list(dataloaders.values())

        # Peek at the first batch of every loader to size the parameters.
        x_dict, y_dict = {}, {}
        for key, loader in zip(keys, loaders):
            x_dict[key], y_dict[key] = next(iter(loader))

        # All formulas describe the same response; their first batches must agree.
        y_ref = y_dict[keys[0]]
        for key in keys:
            if not torch.equal(y_dict[key], y_ref):
                raise ValueError(f"Ys are not the same for {key}")

        params = initializer(x_dict, y_ref, device)  # x is a dict of tensors, y a tensor
        optimizer = torch.optim.Adam([params], lr=lr)

        for epoch in range(epochs):
            progress = tqdm(zip(*loaders), desc=f"Epoch {epoch + 1}/{epochs}", leave=False)
            for batches in progress:
                x_batch_dict = {key: batch[0].to(device) for key, batch in zip(keys, batches)}
                y_batch = batches[0][1].to(device)
                optimizer.zero_grad()
                loss = likelihood(params, x_batch_dict, y_batch)
                loss.backward()
                optimizer.step()
                progress.set_postfix_str(f"loss: {loss.item()}")

        return postprocessor(params, x_dict, y_ref)

    return estimator
68
-
69
-
def check_device():
    """Return the preferred torch device: CUDA, then Apple MPS, then CPU."""
    if torch.cuda.is_available():
        name = "cuda"
    elif torch.backends.mps.is_available():
        name = "mps"
    else:
        name = "cpu"
    return torch.device(name)
@@ -1,153 +0,0 @@
1
- from . import gaussian_copula_factory as gcf
2
- from . import glm_factory as factory
3
- from .. import format
4
- from .. import data
5
- from anndata import AnnData
6
- from scipy.stats import nbinom
7
- import numpy as np
8
- import pandas as pd
9
- import torch
10
- from typing import Union
11
-
12
- ###############################################################################
13
- ## Regression functions that operate on numpy arrays
14
- ###############################################################################
15
-
16
-
def negbin_regression_likelihood(params, X_dict, y):
    """
    Negative log-likelihood of a negative-binomial GLM with log links.

    ``params`` is a flat tensor. The first ``p_mean * G`` entries are the mean
    coefficients and the rest are the dispersion coefficients; each part is
    reshaped to ``(n_features, G)`` with ``G = y.shape[1]``.

    The linear predictors are kept on the log scale:

    - ``y * log(mu)`` is computed as ``y * eta_mean``, so an underflowing
      ``mu`` with ``y == 0`` no longer yields ``0 * -inf = NaN``.
    - ``log(r + mu)`` is computed with ``logaddexp``.
    """
    n_mean_features = X_dict["mean"].shape[1]
    n_dispersion_features = X_dict["dispersion"].shape[1]
    n_outcomes = y.shape[1]

    split = n_mean_features * n_outcomes
    coef_mean = params[:split].reshape(n_mean_features, n_outcomes)
    coef_dispersion = params[split:].reshape(n_dispersion_features, n_outcomes)

    log_r = X_dict["dispersion"] @ coef_dispersion
    log_mu = X_dict["mean"] @ coef_mean
    r = torch.exp(log_r)

    log_likelihood = (
        torch.lgamma(y + r)
        - torch.lgamma(r)
        - torch.lgamma(y + 1)
        + r * log_r
        + y * log_mu
        - (r + y) * torch.logaddexp(log_r, log_mu)
    )

    return -torch.sum(log_likelihood)
41
-
42
-
def negbin_initializer(x_dict, y, device):
    """
    Return a zero, gradient-tracking flat parameter vector on ``device``.

    Its length is ``(p_mean + p_dispersion) * G``, matching the layout expected by
    ``negbin_regression_likelihood`` (mean coefficients first, then dispersion).
    """
    n_outcomes = y.shape[1]
    n_features = x_dict["mean"].shape[1] + x_dict["dispersion"].shape[1]
    return torch.zeros(n_features * n_outcomes, requires_grad=True, device=device)
52
-
53
-
def negbin_postprocessor(params, x_dict, y):
    """
    Split the flat parameter tensor into named coefficient arrays.

    Returns ``{"coef_mean": (p_mean, G), "coef_dispersion": (p_dispersion, G)}``,
    converted with ``format.to_np``.
    """
    n_outcomes = y.shape[1]
    p_mean = x_dict["mean"].shape[1]
    p_dispersion = x_dict["dispersion"].shape[1]
    split = p_mean * n_outcomes

    flat_mean = format.to_np(params[:split])
    flat_dispersion = format.to_np(params[split:])
    return {
        "coef_mean": flat_mean.reshape(p_mean, n_outcomes),
        "coef_dispersion": flat_dispersion.reshape(p_dispersion, n_outcomes),
    }
63
-
64
-
# Dataloader-level estimator: takes a dict of loaders keyed by "mean" and
# "dispersion" and returns {"coef_mean", "coef_dispersion"} numpy arrays.
negbin_regression_array = factory.multiple_formula_regression_factory(
    likelihood=negbin_regression_likelihood,
    initializer=negbin_initializer,
    postprocessor=negbin_postprocessor,
)
68
-
69
-
70
- ###############################################################################
71
- ## Regression functions that operate on AnnData objects
72
- ###############################################################################
73
-
def format_negbin_parameters(
    parameters: dict, var_names: list, mean_coef_index: list,
    dispersion_coef_index: list
) -> dict:
    """
    Label the coefficient arrays in ``parameters`` in place and return the dict.

    Columns are ``var_names``. Rows are the mean/dispersion coefficient names.
    """
    row_labels = {
        "coef_mean": mean_coef_index,
        "coef_dispersion": dispersion_coef_index,
    }
    for key, index in row_labels.items():
        parameters[key] = pd.DataFrame(parameters[key], columns=var_names, index=index)
    return parameters
85
-
def format_negbin_parameters_with_loaders(
    parameters: dict, var_names: list, dls: dict
) -> dict:
    """
    Label negative-binomial coefficients, taking row names from the dataloaders.

    Row names are taken from ``dls["mean"].dataset.x_names`` and
    ``dls["dispersion"].dataset.x_names``.
    """
    return format_negbin_parameters(
        parameters,
        var_names,
        mean_coef_index=dls["mean"].dataset.x_names,
        dispersion_coef_index=dls["dispersion"].dataset.x_names,
    )
94
-
def negbin_regression(
    adata: AnnData, formula: Union[str, dict], chunk_size: int = int(1e4), batch_size=512, **kwargs
) -> dict:
    """
    Fit a negative-binomial GLM to ``adata`` and return labelled coefficients.

    ``formula`` is standardized to the ``"mean"`` and ``"dispersion"`` keys.
    ``kwargs`` are forwarded to ``negbin_regression_array`` (for example ``lr``
    and ``epochs``).
    """
    standardized = data.standardize_formula(formula, allowed_keys=['mean', 'dispersion'])
    loaders = data.multiple_formula_loader(
        adata, standardized, chunk_size=chunk_size, batch_size=batch_size
    )
    fitted = negbin_regression_array(loaders, **kwargs)

    var_names = list(adata.var_names)
    mean_index = loaders["mean"].dataset.x_names
    dispersion_index = loaders["dispersion"].dataset.x_names
    return format_negbin_parameters(fitted, var_names, mean_index, dispersion_index)
107
-
108
- ###############################################################################
109
- ## Copula versions for negative binomial regression
110
- ###############################################################################
111
-
112
-
def negbin_uniformizer(parameters, X_dict, y, epsilon=1e-3, *, rng=None):
    """
    Map negative-binomial counts into (0, 1) with a randomized PIT.

    For a count ``y``, the CDF jumps from ``F(y - 1)`` to ``F(y)``. A uniform draw
    ``v`` picks ``v * F(y) + (1 - v) * F(y - 1)`` inside that jump (the randomized
    probability integral transform). The result is clipped to
    ``[epsilon, 1 - epsilon]`` so that a later ``norm.ppf`` stays finite.

    Parameters
    ----------
    parameters : dict
        Must contain ``"coef_mean"`` and ``"coef_dispersion"``, shaped
        (n_features, n_outcomes).
    X_dict : dict
        Design matrices under ``"mean"`` and ``"dispersion"``.
    y : np.ndarray
        Observed counts, shape (n_obs, n_outcomes).
    epsilon : float
        Clipping margin.
    rng : np.random.Generator, optional
        Source of the jitter ``v``, for reproducible results. When ``None``, the
        global ``np.random`` state is used, exactly as before.
    """
    r = np.exp(X_dict["dispersion"] @ parameters["coef_dispersion"])
    mu = np.exp(X_dict["mean"] @ parameters["coef_mean"])
    dist = nbinom(n=r, p=r / (r + mu))  # build once, reuse for both CDFs
    u1 = dist.cdf(y)
    u2 = np.where(y > 0, dist.cdf(y - 1), 0)
    draw = np.random.uniform if rng is None else rng.uniform
    v = draw(size=y.shape)
    return np.clip(v * u1 + (1 - v) * u2, epsilon, 1 - epsilon)
120
-
121
-
# Dataloader-level copula estimator: fits the negative-binomial marginals, then
# the Gaussian-copula covariance using negbin_uniformizer.
negbin_copula_array = gcf.gaussian_copula_array_factory(
    marginal_model=negbin_regression_array,
    uniformizer=negbin_uniformizer,
)  # should accept a dictionary of dataloaders

# AnnData-facing copula estimator; formulas are standardized to the
# "mean" and "dispersion" keys.
negbin_copula = gcf.gaussian_copula_factory(
    copula_array_fun=negbin_copula_array,
    parameter_formatter=format_negbin_parameters_with_loaders,
    param_name=['mean', 'dispersion'],
)
130
-
131
- ###############################################################################
132
- ## Fast copula versions for negative binomial regression
133
- ###############################################################################
134
-
def fast_negbin_copula_array_factory(top_k: int):
    """
    Build a dataloader-level negative-binomial copula estimator using the fast
    top-k covariance approximation.

    top_k: int
        Number of top genes to model with full covariance.
    """
    return gcf.fast_gaussian_copula_array_factory(
        marginal_model=negbin_regression_array,
        uniformizer=negbin_uniformizer,
        top_k=top_k,
    )
143
-
def fast_negbin_copula_factory(top_k: int):
    """
    Build an AnnData-facing negative-binomial copula estimator that uses the fast
    top-k covariance approximation.

    top_k: int
        Number of top genes to model with full covariance.
    """
    array_estimator = fast_negbin_copula_array_factory(top_k)
    return gcf.gaussian_copula_factory(
        copula_array_fun=array_estimator,
        parameter_formatter=format_negbin_parameters_with_loaders,
        param_name=['mean', 'dispersion'],
    )