scdesigner 0.0.3__tar.gz → 0.0.4__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of scdesigner might be problematic. Click here for more details.

Files changed (74):
  1. {scdesigner-0.0.3 → scdesigner-0.0.4}/PKG-INFO +1 -1
  2. {scdesigner-0.0.3 → scdesigner-0.0.4}/pyproject.toml +1 -1
  3. {scdesigner-0.0.3 → scdesigner-0.0.4}/src/scdesigner/estimators/__init__.py +8 -3
  4. scdesigner-0.0.4/src/scdesigner/estimators/gaussian_copula_factory.py +367 -0
  5. {scdesigner-0.0.3 → scdesigner-0.0.4}/src/scdesigner/estimators/negbin.py +24 -0
  6. {scdesigner-0.0.3 → scdesigner-0.0.4}/src/scdesigner/estimators/poisson.py +24 -0
  7. {scdesigner-0.0.3 → scdesigner-0.0.4}/src/scdesigner/minimal/composite.py +2 -2
  8. scdesigner-0.0.4/src/scdesigner/minimal/copula.py +205 -0
  9. {scdesigner-0.0.3 → scdesigner-0.0.4}/src/scdesigner/minimal/marginal.py +24 -19
  10. {scdesigner-0.0.3 → scdesigner-0.0.4}/src/scdesigner/minimal/negbin.py +1 -1
  11. {scdesigner-0.0.3 → scdesigner-0.0.4}/src/scdesigner/minimal/scd3.py +1 -0
  12. {scdesigner-0.0.3 → scdesigner-0.0.4}/src/scdesigner/minimal/scd3_instances.py +5 -5
  13. scdesigner-0.0.4/src/scdesigner/minimal/standard_copula.py +383 -0
  14. {scdesigner-0.0.3 → scdesigner-0.0.4}/src/scdesigner/minimal/transform.py +27 -30
  15. scdesigner-0.0.4/src/scdesigner/samplers/glm_factory.py +103 -0
  16. {scdesigner-0.0.3 → scdesigner-0.0.4}/src/scdesigner/transform/nullify.py +1 -1
  17. scdesigner-0.0.3/src/scdesigner/estimators/gaussian_copula_factory.py +0 -152
  18. scdesigner-0.0.3/src/scdesigner/minimal/copula.py +0 -33
  19. scdesigner-0.0.3/src/scdesigner/minimal/standard_covariance.py +0 -124
  20. scdesigner-0.0.3/src/scdesigner/samplers/glm_factory.py +0 -41
  21. {scdesigner-0.0.3 → scdesigner-0.0.4}/.gitignore +0 -0
  22. {scdesigner-0.0.3 → scdesigner-0.0.4}/README.md +0 -0
  23. {scdesigner-0.0.3 → scdesigner-0.0.4}/src/scdesigner/__init__.py +0 -0
  24. {scdesigner-0.0.3 → scdesigner-0.0.4}/src/scdesigner/data/__init__.py +0 -0
  25. {scdesigner-0.0.3 → scdesigner-0.0.4}/src/scdesigner/data/formula.py +0 -0
  26. {scdesigner-0.0.3 → scdesigner-0.0.4}/src/scdesigner/data/group.py +0 -0
  27. {scdesigner-0.0.3 → scdesigner-0.0.4}/src/scdesigner/data/sparse.py +0 -0
  28. {scdesigner-0.0.3 → scdesigner-0.0.4}/src/scdesigner/diagnose/__init__.py +0 -0
  29. {scdesigner-0.0.3 → scdesigner-0.0.4}/src/scdesigner/diagnose/aic_bic.py +0 -0
  30. {scdesigner-0.0.3 → scdesigner-0.0.4}/src/scdesigner/diagnose/plot.py +0 -0
  31. {scdesigner-0.0.3 → scdesigner-0.0.4}/src/scdesigner/estimators/bernoulli.py +0 -0
  32. {scdesigner-0.0.3 → scdesigner-0.0.4}/src/scdesigner/estimators/gaussian.py +0 -0
  33. {scdesigner-0.0.3 → scdesigner-0.0.4}/src/scdesigner/estimators/glm_factory.py +0 -0
  34. {scdesigner-0.0.3 → scdesigner-0.0.4}/src/scdesigner/estimators/pnmf.py +0 -0
  35. {scdesigner-0.0.3 → scdesigner-0.0.4}/src/scdesigner/estimators/zero_inflated_negbin.py +0 -0
  36. {scdesigner-0.0.3 → scdesigner-0.0.4}/src/scdesigner/estimators/zero_inflated_poisson.py +0 -0
  37. {scdesigner-0.0.3 → scdesigner-0.0.4}/src/scdesigner/format/__init__.py +0 -0
  38. {scdesigner-0.0.3 → scdesigner-0.0.4}/src/scdesigner/format/format.py +0 -0
  39. {scdesigner-0.0.3 → scdesigner-0.0.4}/src/scdesigner/format/print.py +0 -0
  40. {scdesigner-0.0.3 → scdesigner-0.0.4}/src/scdesigner/minimal/__init__.py +0 -0
  41. {scdesigner-0.0.3 → scdesigner-0.0.4}/src/scdesigner/minimal/bernoulli.py +0 -0
  42. {scdesigner-0.0.3 → scdesigner-0.0.4}/src/scdesigner/minimal/formula.py +0 -0
  43. {scdesigner-0.0.3 → scdesigner-0.0.4}/src/scdesigner/minimal/gaussian.py +0 -0
  44. {scdesigner-0.0.3 → scdesigner-0.0.4}/src/scdesigner/minimal/kwargs.py +0 -0
  45. {scdesigner-0.0.3 → scdesigner-0.0.4}/src/scdesigner/minimal/loader.py +0 -0
  46. {scdesigner-0.0.3 → scdesigner-0.0.4}/src/scdesigner/minimal/positive_nonnegative_matrix_factorization.py +0 -0
  47. {scdesigner-0.0.3 → scdesigner-0.0.4}/src/scdesigner/minimal/simulator.py +0 -0
  48. {scdesigner-0.0.3 → scdesigner-0.0.4}/src/scdesigner/minimal/zero_inflated_negbin.py +0 -0
  49. {scdesigner-0.0.3 → scdesigner-0.0.4}/src/scdesigner/predictors/__init__.py +0 -0
  50. {scdesigner-0.0.3 → scdesigner-0.0.4}/src/scdesigner/predictors/bernoulli.py +0 -0
  51. {scdesigner-0.0.3 → scdesigner-0.0.4}/src/scdesigner/predictors/gaussian.py +0 -0
  52. {scdesigner-0.0.3 → scdesigner-0.0.4}/src/scdesigner/predictors/negbin.py +0 -0
  53. {scdesigner-0.0.3 → scdesigner-0.0.4}/src/scdesigner/predictors/poisson.py +0 -0
  54. {scdesigner-0.0.3 → scdesigner-0.0.4}/src/scdesigner/predictors/zero_inflated_negbin.py +0 -0
  55. {scdesigner-0.0.3 → scdesigner-0.0.4}/src/scdesigner/predictors/zero_inflated_poisson.py +0 -0
  56. {scdesigner-0.0.3 → scdesigner-0.0.4}/src/scdesigner/samplers/__init__.py +0 -0
  57. {scdesigner-0.0.3 → scdesigner-0.0.4}/src/scdesigner/samplers/bernoulli.py +0 -0
  58. {scdesigner-0.0.3 → scdesigner-0.0.4}/src/scdesigner/samplers/gaussian.py +0 -0
  59. {scdesigner-0.0.3 → scdesigner-0.0.4}/src/scdesigner/samplers/negbin.py +0 -0
  60. {scdesigner-0.0.3 → scdesigner-0.0.4}/src/scdesigner/samplers/poisson.py +0 -0
  61. {scdesigner-0.0.3 → scdesigner-0.0.4}/src/scdesigner/samplers/zero_inflated_negbin.py +0 -0
  62. {scdesigner-0.0.3 → scdesigner-0.0.4}/src/scdesigner/samplers/zero_inflated_poisson.py +0 -0
  63. {scdesigner-0.0.3 → scdesigner-0.0.4}/src/scdesigner/simulators/__init__.py +0 -0
  64. {scdesigner-0.0.3 → scdesigner-0.0.4}/src/scdesigner/simulators/composite_regressor.py +0 -0
  65. {scdesigner-0.0.3 → scdesigner-0.0.4}/src/scdesigner/simulators/glm_simulator.py +0 -0
  66. {scdesigner-0.0.3 → scdesigner-0.0.4}/src/scdesigner/simulators/pnmf_regression.py +0 -0
  67. {scdesigner-0.0.3 → scdesigner-0.0.4}/src/scdesigner/transform/__init__.py +0 -0
  68. {scdesigner-0.0.3 → scdesigner-0.0.4}/src/scdesigner/transform/amplify.py +0 -0
  69. {scdesigner-0.0.3 → scdesigner-0.0.4}/src/scdesigner/transform/mask.py +0 -0
  70. {scdesigner-0.0.3 → scdesigner-0.0.4}/src/scdesigner/transform/split.py +0 -0
  71. {scdesigner-0.0.3 → scdesigner-0.0.4}/src/scdesigner/transform/substitute.py +0 -0
  72. {scdesigner-0.0.3 → scdesigner-0.0.4}/tests/__init__.py +0 -0
  73. {scdesigner-0.0.3 → scdesigner-0.0.4}/tests/test_negative_binomial.py +0 -0
  74. {scdesigner-0.0.3 → scdesigner-0.0.4}/tests/test_simulator.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: scdesigner
3
- Version: 0.0.3
3
+ Version: 0.0.4
4
4
  Summary: Interactive simulation for rigorous and transparent multi-omics analysis.
5
5
  Project-URL: Homepage, https://github.com/krisrs1128/scDesigner/
6
6
  Project-URL: Issues, https://github.com/krisrs1128/scDesigner/Issues/
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "scdesigner"
3
- version = "0.0.3"
3
+ version = "0.0.4"
4
4
  authors = [
5
5
  { name="Kris Sankaran", email="ksankaran@wisc.edu" },
6
6
  ]
@@ -1,6 +1,6 @@
1
- from .negbin import negbin_regression, negbin_copula
2
- from .gaussian_copula_factory import group_indices
3
- from .poisson import poisson_regression, poisson_copula
1
+ from .negbin import negbin_regression, negbin_copula, fast_negbin_copula_factory
2
+ from .gaussian_copula_factory import group_indices, fast_copula_covariance, FastCovarianceStructure, fast_gaussian_copula_array_factory
3
+ from .poisson import poisson_regression, poisson_copula, fast_poisson_copula_factory
4
4
  from .bernoulli import bernoulli_regression, bernoulli_copula
5
5
  from .gaussian import gaussian_regression, gaussian_copula
6
6
  from .zero_inflated_negbin import (
@@ -24,4 +24,9 @@ __all__ = [
24
24
  "zero_inflated_negbin_regression",
25
25
  "zero_inflated_poisson_regression",
26
26
  "multiple_formula_regression_factory",
27
+ "fast_copula_covariance",
28
+ "FastCovarianceStructure",
29
+ "fast_gaussian_copula_array_factory",
30
+ "fast_negbin_copula_factory",
31
+ "fast_poisson_copula_factory",
27
32
  ]
@@ -0,0 +1,367 @@
1
+ from ..data import stack_collate, multiple_formula_group_loader
2
+ from .. import data
3
+ from anndata import AnnData
4
+ from collections.abc import Callable
5
+ from typing import Union
6
+ from scipy.stats import norm
7
+ from torch.utils.data import DataLoader
8
+ import numpy as np
9
+ import pandas as pd
10
+
11
+ ###############################################################################
12
+ ## General copula factory functions
13
+ ###############################################################################
14
+
15
+
16
def gaussian_copula_array_factory(marginal_model: Callable, uniformizer: Callable):
    """Build a fitter for marginals followed by a dense Gaussian-copula covariance.

    Args:
        marginal_model: Called as ``marginal_model(loaders, lr=..., epochs=..., **kwargs)``
            on re-wrapped copies of the loaders; its returned dict is used as the
            parameter dictionary.
        uniformizer: Forwarded to ``copula_covariance``.

    Returns:
        ``copula_fun(loaders, lr=0.1, epochs=40, **kwargs)``, which returns the
        marginal parameters with an added ``"covariance"`` entry.
    """
    def copula_fun(loaders: dict[str, DataLoader], lr: float = 0.1, epochs: int = 40, **kwargs):
        # The marginal fit ignores groupings: each loader is re-wrapped by
        # strip_dataloader (collated with groups=False). Datasets whose class
        # name contains "Stack" are re-wrapped with pop=True.
        stripped = {
            name: strip_dataloader(loader, pop="Stack" in type(loader.dataset).__name__)
            for name, loader in loaders.items()
        }
        fitted = marginal_model(stripped, lr=lr, epochs=epochs, **kwargs)

        # The covariance uses the original (grouped) loaders, so it can differ per group.
        fitted["covariance"] = copula_covariance(fitted, loaders, uniformizer)
        return fitted

    return copula_fun
32
+
33
+
34
def fast_gaussian_copula_array_factory(marginal_model: Callable, uniformizer: Callable, top_k: int):
    """
    Factory function for fast Gaussian copula array computation using top-k gene modeling.

    Args:
        marginal_model: Called as ``marginal_model(loaders, lr=..., epochs=..., **kwargs)``
            on re-wrapped copies of the loaders; its returned dict is used as the
            parameter dictionary.
        uniformizer: Forwarded to ``fast_copula_covariance``.
        top_k: Number of genes whose covariance is modeled in full; forwarded to
            ``fast_copula_covariance``.

    Returns:
        ``copula_fun(loaders, lr=0.1, epochs=40, **kwargs)``, which returns the
        marginal parameters with an added ``"covariance"`` entry.
    """
    def copula_fun(loaders: dict[str, DataLoader], lr: float = 0.1, epochs: int = 40, **kwargs):
        # Marginals ignore groupings: re-wrap every loader via strip_dataloader,
        # using pop=True for datasets whose class name contains "Stack".
        stripped = {}
        for name, loader in loaders.items():
            is_stacked = "Stack" in type(loader.dataset).__name__
            stripped[name] = strip_dataloader(loader, pop=is_stacked)

        fitted = marginal_model(stripped, lr=lr, epochs=epochs, **kwargs)

        # Group-aware covariance via the top-k approximation.
        fitted["covariance"] = fast_copula_covariance(fitted, loaders, uniformizer, top_k)
        return fitted

    return copula_fun
54
+
55
+
56
def gaussian_copula_factory(copula_array_fun: Callable,
                            parameter_formatter: Callable,
                            param_name: list = None):
    """Wrap an array-level copula fitter into an AnnData-level estimator.

    Args:
        copula_array_fun: Called as ``copula_array_fun(loaders, **kwargs)`` with the
            dictionary of loaders built from the AnnData object.
        parameter_formatter: Called as ``parameter_formatter(params, var_names, loaders)``.
        param_name: When given, the formula is passed through
            ``data.standardize_formula(formula, param_name)`` first.

    Returns:
        ``copula_fun(adata, formula="~ 1", grouping_var=None, chunk_size=10000,
        batch_size=512, **kwargs)``, returning the formatted parameter dict.
    """
    def copula_fun(
        adata: AnnData,
        formula: Union[str, dict] = "~ 1",
        grouping_var: str = None,
        chunk_size: int = int(1e4),
        batch_size: int = 512,
        **kwargs
    ) -> dict:
        if param_name is not None:
            formula = data.standardize_formula(formula, param_name)

        # One dataloader per formula key, grouped by grouping_var.
        loaders = multiple_formula_group_loader(
            adata, formula, grouping_var, chunk_size=chunk_size, batch_size=batch_size
        )
        fitted = copula_array_fun(loaders, **kwargs)

        # The formatter receives the full loader dict so it can extract what it needs.
        fitted = parameter_formatter(fitted, adata.var_names, loaders)
        fitted["covariance"] = format_copula_parameters(fitted, adata.var_names)
        return fitted

    return copula_fun
88
+
89
+
90
+
91
+
92
def copula_covariance(parameters: dict, loaders: dict[str, DataLoader], uniformizer: Callable):
    """Estimate the Gaussian-copula covariance of normal scores, per group.

    Each batch is mapped to uniforms by ``uniformizer(parameters, x_dict, y)``,
    then to normal scores with ``norm.ppf``. The first and second moments of the
    scores are accumulated separately for every group.

    Args:
        parameters: Fitted marginal parameters, forwarded to ``uniformizer``.
        loaders: Formula name -> loader. Batches are indexed as ``(x, y, memberships)``.
            ``y`` and ``memberships`` are taken from the first loader.
        uniformizer: Callable ``(parameters, x_batch_dict, y_batch) -> u``.

    Returns:
        A ``(D, D)`` array when there is a single group, else a dict group -> array.
    """
    first_loader = next(iter(loaders.values()))
    n_genes = next(iter(first_loader))[1].shape[1]  # number of columns of y
    groups = first_loader.dataset.groups

    totals = {g: np.zeros(n_genes) for g in groups}
    # NOTE(review): second moments start at the identity, not zero, so each
    # estimate equals the moment-based covariance plus I / n_g. This may be an
    # intentional ridge term; confirm before changing.
    cross = {g: np.eye(n_genes) for g in groups}
    counts = {g: 0 for g in groups}
    names = list(loaders)

    for batches in zip(*loaders.values()):
        x_by_name = {name: batch[0].cpu().numpy() for name, batch in zip(names, batches)}
        y = batches[0][1].cpu().numpy()
        labels = np.array(batches[0][2])  # group labels; should be identical for all keys

        u = uniformizer(parameters, x_by_name, y)
        for g in groups:
            rows = np.where(labels == g)
            z = norm.ppf(u[rows])
            cross[g] += z.T @ z
            totals[g] += z.sum(axis=0)
            counts[g] += len(rows[0])

    estimates = {}
    for g in groups:
        centre = totals[g] / counts[g]
        estimates[g] = cross[g] / counts[g] - np.outer(centre, centre)

    if len(groups) == 1:
        return next(iter(estimates.values()))
    return estimates
127
+
128
+
129
def fast_copula_covariance(parameters: dict, loaders: dict[str, DataLoader], uniformizer: Callable, top_k: int):
    """
    Compute an efficient approximation of copula covariance by modeling only the top-k most prevalent genes
    with full covariance and approximating the rest with diagonal covariance.

    Genes are ranked by total expression, i.e. the column sums of ``y`` over the
    first loader. The data are then read a second time to accumulate moments of
    the normal scores ``norm.ppf(u)``.

    Parameters:
    -----------
    parameters : dict
        Model parameters dictionary, forwarded to ``uniformizer``.
    loaders : dict[str, DataLoader]
        Dictionary of data loaders. Batches are indexed as ``(x, y, memberships)``.
        The loaders are iterated in lockstep, and ``y``/``memberships`` are read
        from the first one.
    uniformizer : Callable
        Called as ``uniformizer(parameters, x_batch_dict, y_batch)``. Its output
        is passed to ``norm.ppf``, so it should lie in (0, 1).
    top_k : int
        Number of top genes to model with full covariance.

    Returns:
    --------
    np.ndarray, dict or FastCovarianceStructure:
        - If ``top_k >= D``: the dense result of ``copula_covariance`` (an array
          for a single group, a dict of arrays otherwise).
        - If single group: FastCovarianceStructure containing:
            * top_k_cov: (top_k, top_k) full covariance matrix for top genes
            * remaining_var: (remaining_genes,) diagonal variances for remaining genes
            * top_k_indices: indices of top-k genes
            * remaining_indices: indices of remaining genes
            * gene_total_expression: total expression levels for gene selection
        - If multiple groups: dict mapping group names to FastCovarianceStructure
          objects. Groups with no observations are omitted.

    Raises:
    -------
    ValueError
        If ``top_k`` is not positive, or if the only group has no observations.
    """
    first_loader = next(iter(loaders.values()))
    D = next(iter(first_loader))[1].shape[1]  # dimension of y
    groups = first_loader.dataset.groups  # a list of strings of group names

    # Validate top_k parameter
    if top_k <= 0:
        raise ValueError("top_k must be a positive integer")
    if top_k >= D:
        # Every gene would fall in the dense block: use the regular estimator.
        return copula_covariance(parameters, loaders, uniformizer)

    # Step 1: total expression per gene. Only y from the first loader is used,
    # so iterate that loader alone rather than zipping (and collating) all of them.
    gene_total_expression = np.zeros(D)
    for batch in first_loader:
        gene_total_expression += batch[1].cpu().numpy().sum(axis=0)

    # Step 2: rank once (ascending) and split into the top-k block and the rest.
    order = np.argsort(gene_total_expression)
    top_k_indices = order[-top_k:]
    remaining_indices = order[:-top_k]
    n_remaining = len(remaining_indices)

    # Step 3: accumulate first/second moments of the normal scores per group.
    sums_top_k = {g: np.zeros(top_k) for g in groups}
    second_moments_top_k = {g: np.zeros((top_k, top_k)) for g in groups}
    sums_remaining = {g: np.zeros(n_remaining) for g in groups}
    second_moments_remaining = {g: np.zeros(n_remaining) for g in groups}
    Ng = {g: 0 for g in groups}

    keys = list(loaders.keys())
    for batches in zip(*loaders.values()):
        x_batch_dict = {key: batch[0].cpu().numpy() for key, batch in zip(keys, batches)}
        y_batch = batches[0][1].cpu().numpy()
        memberships = np.array(batches[0][2])  # should be identical for all keys

        u = uniformizer(parameters, x_batch_dict, y_batch)

        for g in groups:
            ix = np.where(memberships == g)
            if len(ix[0]) == 0:
                continue

            z = norm().ppf(u[ix])

            # Top-k genes: full second-moment matrix.
            z_top_k = z[:, top_k_indices]
            second_moments_top_k[g] += z_top_k.T @ z_top_k
            sums_top_k[g] += z_top_k.sum(axis=0)

            # Remaining genes: per-gene (diagonal) second moments only.
            z_remaining = z[:, remaining_indices]
            second_moments_remaining[g] += (z_remaining ** 2).sum(axis=0)
            sums_remaining[g] += z_remaining.sum(axis=0)

            Ng[g] += len(ix[0])

    # Step 4: turn moments into covariance structures.
    result = {}
    for g in groups:
        if Ng[g] == 0:
            continue  # no observations for this group

        mean_top_k = sums_top_k[g] / Ng[g]
        cov_top_k = second_moments_top_k[g] / Ng[g] - np.outer(mean_top_k, mean_top_k)

        mean_remaining = sums_remaining[g] / Ng[g]
        var_remaining = second_moments_remaining[g] / Ng[g] - mean_remaining ** 2

        result[g] = FastCovarianceStructure(
            top_k_cov=cov_top_k,
            remaining_var=var_remaining,
            top_k_indices=top_k_indices,
            remaining_indices=remaining_indices,
            gene_total_expression=gene_total_expression,
        )

    if len(groups) == 1:
        if not result:
            raise ValueError(f"no observations found for group {groups[0]!r}")
        return next(iter(result.values()))
    return result
250
+
251
+
252
class FastCovarianceStructure:
    """
    Dense-block-plus-diagonal covariance used for fast copula sampling.

    A full covariance is kept only for the ``top_k`` selected genes. Every other
    gene keeps just its variance; covariances between the two sets, and among
    the remaining genes, are treated as zero.

    Attributes:
    -----------
    top_k_cov : np.ndarray
        Full covariance matrix for the top-k genes, shape (top_k, top_k)
    remaining_var : np.ndarray
        Variances for the remaining genes, shape (remaining_genes,)
    top_k_indices : np.ndarray
        Positions of the top-k genes in the original gene ordering
    remaining_indices : np.ndarray
        Positions of the remaining genes in the original gene ordering
    gene_total_expression : np.ndarray
        Per-gene totals used to choose the top-k genes, shape (total_genes,)
    top_k : int
        ``len(top_k_indices)``
    total_genes : int
        ``len(top_k_indices) + len(remaining_indices)``
    """

    def __init__(self, top_k_cov, remaining_var, top_k_indices, remaining_indices, gene_total_expression):
        self.top_k_cov = top_k_cov
        self.remaining_var = remaining_var
        self.top_k_indices = top_k_indices
        self.remaining_indices = remaining_indices
        self.gene_total_expression = gene_total_expression
        n_top = len(top_k_indices)
        self.top_k = n_top
        self.total_genes = n_top + len(remaining_indices)

    def __repr__(self):
        n_rest = len(self.remaining_indices)
        return (
            f"FastCovarianceStructure(top_k={self.top_k}, "
            f"remaining_genes={n_rest}, total_genes={self.total_genes})"
        )

    def to_full_matrix(self):
        """
        Expand into a dense covariance matrix (for compatibility/debugging).

        Returns:
        --------
        np.ndarray : shape (total_genes, total_genes). The top-k block and the
            remaining diagonal are filled in; every other entry is zero.
        """
        n = self.total_genes
        dense = np.zeros((n, n))

        top = self.top_k_indices
        dense[np.ix_(top, top)] = self.top_k_cov

        rest = self.remaining_indices
        dense[rest, rest] = self.remaining_var
        return dense
302
+
303
+
304
+
305
+ ###############################################################################
306
+ ## Helpers to prepare and postprocess copula parameters
307
+ ###############################################################################
308
+
309
+
310
def group_indices(grouping_var: Union[str, None], obs: pd.DataFrame) -> dict:
    """
    Returns a dictionary of group indices for each group in the grouping variable.

    Parameters
    ----------
    grouping_var : str or None
        Column of ``obs`` that defines the groups. When None, all rows form one
        group labelled ``"shared_group"``. That group is stored in an
        ``"_copula_group"`` column, which is added to ``obs`` in place if it is
        not already present.
    obs : pd.DataFrame
        Per-observation table holding the grouping column.

    Returns
    -------
    dict
        Maps each group label to a numpy array of integer row positions. For a
        categorical column the keys follow its category order, and unused
        categories map to empty arrays. For other dtypes the keys are the sorted
        unique non-null values.
    """
    if grouping_var is None:
        grouping_var = "_copula_group"
        # Check for the same column we create. The check previously tested
        # "copula_group", which raised KeyError when obs had a "copula_group"
        # column but no "_copula_group" column.
        if grouping_var not in obs.columns:
            obs[grouping_var] = pd.Categorical(["shared_group"] * len(obs))

    column = obs[grouping_var]
    # pd.Categorical keeps an existing categorical's categories and their order.
    # For other dtypes it derives sorted categories, where `.dtype.categories`
    # would raise AttributeError.
    categories = pd.Categorical(column).categories
    values = column.values
    return {group: np.where(values == group)[0] for group in categories}
323
+
324
+
325
def clip(u: np.array, min: float = 1e-5, max: float = 1 - 1e-5) -> np.array:
    """Clamp ``u`` into ``[min, max]`` in place and return the same array.

    Entries below ``min`` are raised to ``min`` first; then entries above
    ``max`` are lowered to ``max``. With the defaults, values end up strictly
    inside (0, 1).
    """
    lower, upper = min, max
    too_small = u < lower
    u[too_small] = lower
    too_large = u > upper
    u[too_large] = upper
    return u
329
+
330
+
331
def format_copula_parameters(parameters: dict, var_names: list):
    '''
    Label the covariance in ``parameters["covariance"]`` with gene names.

    - A FastCovarianceStructure, or a dict containing any FastCovarianceStructure
      values, is returned unchanged.
    - A single (non-dict) matrix is returned as a DataFrame indexed by ``var_names``
      on both axes.
    - A dict of per-group matrices has each non-FastCovarianceStructure entry
      replaced in place by such a DataFrame; the same dict is returned.
    '''
    covariance = parameters["covariance"]
    names = list(var_names)

    # Fast structures carry their own layout, so pass them through untouched.
    if isinstance(covariance, FastCovarianceStructure):
        return covariance
    if isinstance(covariance, dict) and any(
        isinstance(value, FastCovarianceStructure) for value in covariance.values()
    ):
        return covariance

    if type(covariance) is not dict:
        # One dense matrix shared by all observations.
        return pd.DataFrame(covariance, columns=names, index=names)

    # One dense matrix per group: relabel each entry in place.
    for group in list(covariance):
        matrix = covariance[group]
        if isinstance(matrix, FastCovarianceStructure):
            continue
        covariance[group] = pd.DataFrame(matrix, columns=names, index=names)
    return covariance
359
+
360
+
361
def strip_dataloader(dataloader, pop=False):
    """Re-wrap ``dataloader`` with a ``stack_collate(pop=pop, groups=False)`` collate.

    The new DataLoader reuses the original dataset and batch sampler, so it
    yields the same batches; only the collate function changes.
    """
    collate = stack_collate(pop=pop, groups=False)
    return DataLoader(
        dataset=dataloader.dataset,
        batch_sampler=dataloader.batch_sampler,
        collate_fn=collate,
    )
367
+
@@ -127,3 +127,27 @@ negbin_copula = gcf.gaussian_copula_factory(
127
127
  negbin_copula_array, format_negbin_parameters_with_loaders,
128
128
  param_name=['mean', 'dispersion']
129
129
  )
130
+
131
+ ###############################################################################
132
+ ## Fast copula versions for negative binomial regression
133
+ ###############################################################################
134
+
135
def fast_negbin_copula_array_factory(top_k: int):
    """Array-level negative binomial copula fitter with a top-k covariance.

    top_k: int
        Number of top genes to model with full covariance
    """
    fitter = gcf.fast_gaussian_copula_array_factory(
        negbin_regression_array, negbin_uniformizer, top_k
    )
    return fitter
143
+
144
def fast_negbin_copula_factory(top_k: int):
    """AnnData-level negative binomial copula estimator with a top-k covariance.

    top_k: int
        Number of top genes to model with full covariance
    """
    return gcf.gaussian_copula_factory(
        fast_negbin_copula_array_factory(top_k),
        format_negbin_parameters_with_loaders,
        param_name=['mean', 'dispersion'],
    )
@@ -98,3 +98,27 @@ poisson_copula_array = gcf.gaussian_copula_array_factory(
98
98
  poisson_copula = gcf.gaussian_copula_factory(
99
99
  poisson_copula_array, format_poisson_parameters_with_loaders, ['mean']
100
100
  )
101
+
102
+ ###############################################################################
103
+ ## Fast copula versions for poisson regression
104
+ ###############################################################################
105
+
106
def fast_poisson_copula_array_factory(top_k: int):
    """Array-level Poisson copula fitter with a top-k covariance.

    top_k: int
        Number of top genes to model with full covariance
    """
    fitter = gcf.fast_gaussian_copula_array_factory(
        poisson_regression_array, poisson_uniformizer, top_k
    )
    return fitter
114
+
115
def fast_poisson_copula_factory(top_k: int):
    """AnnData-level Poisson copula estimator with a top-k covariance.

    top_k: int
        Number of top genes to model with full covariance
    """
    return gcf.gaussian_copula_factory(
        fast_poisson_copula_array_factory(top_k),
        format_poisson_parameters_with_loaders,
        param_name=['mean'],
    )
@@ -1,6 +1,6 @@
1
1
  from .loader import obs_loader
2
2
  from .scd3 import SCD3Simulator
3
- from .standard_covariance import StandardCovariance
3
+ from .standard_copula import StandardCopula
4
4
  from anndata import AnnData
5
5
  from typing import Dict, Optional, List
6
6
  import numpy as np
@@ -10,7 +10,7 @@ class CompositeCopula(SCD3Simulator):
10
10
  def __init__(self, marginals: List,
11
11
  copula_formula: Optional[str] = None) -> None:
12
12
  self.marginals = marginals
13
- self.copula = StandardCovariance(copula_formula)
13
+ self.copula = StandardCopula(copula_formula)
14
14
  self.template = None
15
15
  self.parameters = None
16
16
  self.merged_formula = None
@@ -0,0 +1,205 @@
1
+ from typing import Dict, Callable, Tuple
2
+ import torch
3
+ from anndata import AnnData
4
+ from .loader import adata_loader
5
+ from abc import ABC, abstractmethod
6
+ import numpy as np
7
+ import pandas as pd
8
+ from typing import Optional, Union
9
class Copula(ABC):
    """Abstract base for copula models layered over marginal models.

    Subclasses implement ``fit``, ``pseudo_obs``, ``likelihood`` and
    ``num_params``. This base class builds the data loader and forwards
    ``decorrelate`` / ``correlate`` edits to the per-group objects in
    ``self.parameters``.

    NOTE(review): with ``group=None`` those edits iterate ``self.groups``, which
    this class never sets. Subclasses are presumably expected to define it; confirm.
    """

    def __init__(self, formula: str, **kwargs):
        self.formula = formula
        self.loader = None
        self.n_outcomes = None
        self.parameters = None  # Should be a dictionary of CovarianceStructure objects

    def setup_data(self, adata: AnnData, marginal_formula: Dict[str, str], batch_size: int = 1024, **kwargs):
        """Attach ``adata`` and build the data loader for this copula.

        Args:
            adata (AnnData): Data to model; stored as ``self.adata``.
            marginal_formula (Dict[str, str]): Marginal-model formulas, merged into
                ``self.formula`` with ``|`` (entries from ``marginal_formula`` win).
            batch_size (int): Batch size forwarded to ``adata_loader``.
            **kwargs: Forwarded to ``adata_loader``.
        """
        self.adata = adata
        # NOTE(review): `|` is dict union, so self.formula must already be a dict
        # here despite the `str` annotation on __init__. Confirm with callers.
        self.formula = self.formula | marginal_formula
        self.loader = adata_loader(adata, self.formula, batch_size=batch_size, **kwargs)
        # n_outcomes is the column count of the first element of the first batch.
        first_block, _ = next(iter(self.loader))
        self.n_outcomes = first_block.shape[1]

    def _target_groups(self, group: Union[str, list, None]):
        """Resolve the ``group`` argument of decorrelate/correlate to group names."""
        if group is None:
            return self.groups
        if isinstance(group, str):
            return [group]
        return group

    def decorrelate(self, row_pattern: str, col_pattern: str, group: Union[str, list, None] = None):
        """Decorrelate the covariance matrix for the given row and column patterns.

        Args:
            row_pattern (str): The regex pattern for the row names to match.
            col_pattern (str): The regex pattern for the column names to match.
            group (Union[str, list, None]): The group or groups to apply the transformation to. If None, the transformation is applied to all groups.
        """
        for g in self._target_groups(group):
            self.parameters[g].decorrelate(row_pattern, col_pattern)

    def correlate(self, factor: float, row_pattern: str, col_pattern: str, group: Union[str, list, None] = None):
        """Multiply selected off-diagonal entries by factor.

        Args:
            factor (float): The factor to multiply the off-diagonal entries by.
            row_pattern (str): The regex pattern for the row names to match.
            col_pattern (str): The regex pattern for the column names to match.
            group (Union[str, list, None]): The group or groups to apply the transformation to. If None, the transformation is applied to all groups.
        """
        for g in self._target_groups(group):
            # The per-group object takes factor as its last argument.
            self.parameters[g].correlate(row_pattern, col_pattern, factor)

    @abstractmethod
    def fit(self, uniformizer: Callable, **kwargs):
        """Estimate the copula parameters (implemented by subclasses)."""
        raise NotImplementedError

    @abstractmethod
    def pseudo_obs(self, x_dict: Dict):
        """Generate pseudo-observations from the copula (implemented by subclasses)."""
        raise NotImplementedError

    @abstractmethod
    def likelihood(self, uniformizer: Callable, batch: Tuple[torch.Tensor, Dict[str, torch.Tensor]]):
        """Evaluate the copula likelihood on a batch (implemented by subclasses)."""
        raise NotImplementedError

    @abstractmethod
    def num_params(self, **kwargs):
        """Return the number of copula parameters (implemented by subclasses)."""
        raise NotImplementedError
77
+
78
class CovarianceStructure:
    """
    Efficient storage for covariance matrices in copula-based gene expression modeling.

    Stores either a full covariance matrix, or a dense block for a subset of
    "modeled" genes plus diagonal variances for the remaining genes. Keeping
    only the diagonal for the remaining genes saves memory on large datasets.

    Attributes
    ----------
    cov : pd.DataFrame
        Covariance matrix for modeled genes with gene names as index/columns
    modeled_indices : np.ndarray
        Indices of modeled genes in original ordering
    remaining_var : pd.Series or None
        Diagonal variances for remaining genes, None if full matrix stored
    remaining_indices : np.ndarray or None
        Indices of remaining genes in original ordering
    num_modeled_genes : int
        Number of modeled genes
    num_remaining_genes : int
        Number of remaining genes (0 if full matrix stored)
    total_genes : int
        Total number of genes
    """

    def __init__(self, cov: np.ndarray,
                 modeled_names: pd.Index,
                 modeled_indices: Optional[np.ndarray] = None,
                 remaining_var: Optional[np.ndarray] = None,
                 remaining_indices: Optional[np.ndarray] = None,
                 remaining_names: Optional[pd.Index] = None):
        """Initialize a CovarianceStructure object.

        Args:
            cov (np.ndarray): Covariance matrix for modeled genes, shape (n_modeled_genes, n_modeled_genes)
            modeled_names (pd.Index): Gene names for the modeled genes
            modeled_indices (Optional[np.ndarray], optional): Indices of modeled genes in original ordering. Defaults to sequential indices.
            remaining_var (Optional[np.ndarray], optional): Diagonal variances for remaining genes, shape (n_remaining_genes,)
            remaining_indices (Optional[np.ndarray], optional): Indices of remaining genes in original ordering
            remaining_names (Optional[pd.Index], optional): Gene names for remaining genes
        """
        self.cov = pd.DataFrame(cov, index=modeled_names, columns=modeled_names)

        if modeled_indices is not None:
            self.modeled_indices = modeled_indices
        else:
            self.modeled_indices = np.arange(len(modeled_names))

        if remaining_var is not None:
            self.remaining_var = pd.Series(remaining_var, index=remaining_names)
        else:
            self.remaining_var = None

        self.remaining_indices = remaining_indices
        self.num_modeled_genes = len(modeled_names)
        self.num_remaining_genes = len(remaining_indices) if remaining_indices is not None else 0
        self.total_genes = self.num_modeled_genes + self.num_remaining_genes

    def __repr__(self):
        if self.remaining_var is None:
            return self.cov.__repr__()
        # Implicit concatenation: the previous backslash-continued literal put
        # the continuation line's leading whitespace into the repr.
        return (
            f"CovarianceStructure(modeled_genes={self.num_modeled_genes}, "
            f"total_genes={self.total_genes})"
        )

    def _repr_html_(self):
        """Jupyter Notebook display"""
        if self.remaining_var is None:
            return self.cov._repr_html_()
        html = f"<b>CovarianceStructure:</b> {self.num_modeled_genes} modeled genes, {self.total_genes} total<br>"
        html += "<h4>Modeled Covariance Matrix</h4>" + self.cov._repr_html_()
        html += "<h4>Remaining Gene Variances</h4>" + self.remaining_var.to_frame("variance").T._repr_html_()
        return html

    def _offdiagonal_mask(self, row_pattern: str, col_pattern: str) -> np.ndarray:
        """Union of the column-pattern and row-pattern masks over ``self.cov``, diagonal excluded.

        NOTE(review): presumably ``data_frame_mask(df, r, c)`` selects cells whose
        row name matches ``r`` and column name matches ``c``. The union would then
        cover every cell in a matching column OR a matching row. Confirm.
        """
        # Local import, kept as in the original (presumably avoids an import cycle).
        from .transform import data_frame_mask
        m1 = data_frame_mask(self.cov, ".", col_pattern)
        m2 = data_frame_mask(self.cov, row_pattern, ".")
        mask = (m1 | m2)
        np.fill_diagonal(mask, False)
        return mask

    def _write_values(self, values: np.ndarray) -> None:
        """Write ``values`` back into ``self.cov`` in place.

        Writing through ``self.cov.values`` fails under pandas Copy-on-Write
        (``.values`` is read-only there), so go through ``iloc`` instead. This
        also keeps the DataFrame object identity.
        """
        self.cov.iloc[:, :] = values

    def decorrelate(self, row_pattern: str, col_pattern: str):
        """Decorrelate the covariance matrix for the given row and column patterns.

        Sets the selected off-diagonal entries of ``self.cov`` to 0. Only the
        modeled block is affected; the remaining genes are already diagonal.
        """
        mask = self._offdiagonal_mask(row_pattern, col_pattern)
        values = self.cov.to_numpy(copy=True)
        values[mask] = 0
        self._write_values(values)

    def correlate(self, row_pattern: str, col_pattern: str, factor: float):
        """Multiply selected off-diagonal entries by factor.

        Args:
            row_pattern (str): The regex pattern for the row names to match.
            col_pattern (str): The regex pattern for the column names to match.
            factor (float): The factor to multiply the off-diagonal entries by.
        """
        mask = self._offdiagonal_mask(row_pattern, col_pattern)
        values = self.cov.to_numpy(copy=True)
        values[mask] = values[mask] * factor
        self._write_values(values)

    @property
    def shape(self):
        return (self.total_genes, self.total_genes)

    def to_full_matrix(self):
        """
        Convert to full covariance matrix for compatibility/debugging.

        Returns:
        --------
        np.ndarray : Full covariance matrix with shape (total_genes, total_genes).
            It is always a fresh array, so mutating it does not alter ``self.cov``.
        """
        if self.remaining_var is None:
            return self.cov.to_numpy(copy=True)

        full_cov = np.zeros((self.total_genes, self.total_genes))

        # Fill in the modeled block
        ix_modeled = np.ix_(self.modeled_indices, self.modeled_indices)
        full_cov[ix_modeled] = self.cov.to_numpy()

        # Fill in diagonal for remaining genes
        full_cov[self.remaining_indices, self.remaining_indices] = self.remaining_var.to_numpy()

        return full_cov