scdesigner 0.0.3__py3-none-any.whl → 0.0.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of scdesigner might be problematic.

@@ -1,6 +1,6 @@
-from .negbin import negbin_regression, negbin_copula
-from .gaussian_copula_factory import group_indices
-from .poisson import poisson_regression, poisson_copula
+from .negbin import negbin_regression, negbin_copula, fast_negbin_copula_factory
+from .gaussian_copula_factory import group_indices, fast_copula_covariance, FastCovarianceStructure, fast_gaussian_copula_array_factory
+from .poisson import poisson_regression, poisson_copula, fast_poisson_copula_factory
 from .bernoulli import bernoulli_regression, bernoulli_copula
 from .gaussian import gaussian_regression, gaussian_copula
 from .zero_inflated_negbin import (
@@ -24,4 +24,9 @@ __all__ = [
     "zero_inflated_negbin_regression",
     "zero_inflated_poisson_regression",
     "multiple_formula_regression_factory",
+    "fast_copula_covariance",
+    "FastCovarianceStructure",
+    "fast_gaussian_copula_array_factory",
+    "fast_negbin_copula_factory",
+    "fast_poisson_copula_factory",
 ]
@@ -31,6 +31,28 @@ def gaussian_copula_array_factory(marginal_model: Callable, uniformizer: Callabl
     return copula_fun
 
 
+def fast_gaussian_copula_array_factory(marginal_model: Callable, uniformizer: Callable, top_k: int):
+    """
+    Factory function for fast Gaussian copula array computation using top-k gene modeling.
+    """
+    def copula_fun(loaders: dict[str, DataLoader], lr: float = 0.1, epochs: int = 40, **kwargs):
+        # For the marginal model, ignore the groupings:
+        # strip all dataloaders and create a dictionary to pass to marginal_model
+        formula_loaders = {}
+        for key in loaders.keys():
+            formula_loaders[key] = strip_dataloader(loaders[key], pop="Stack" in type(loaders[key].dataset).__name__)
+
+        # Call marginal_model with the dictionary of stripped dataloaders
+        parameters = marginal_model(formula_loaders, lr=lr, epochs=epochs, **kwargs)
+
+        # Estimate covariance using the fast method, allowing for different groups
+        parameters["covariance"] = fast_copula_covariance(parameters, loaders, uniformizer, top_k)
+        return parameters
+
+    return copula_fun
+
+
 def gaussian_copula_factory(copula_array_fun: Callable,
                             parameter_formatter: Callable,
                             param_name: list = None):
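For readers following the diff: the factory above returns a copula_fun closure that first fits the marginals on stripped loaders and then attaches the result of fast_copula_covariance (a FastCovarianceStructure, defined later in this file, or a dict of them keyed by group) under parameters["covariance"]. A minimal, hypothetical usage sketch, where my_marginal_model, my_uniformizer, and loaders are placeholders rather than names from the package:

    # Hypothetical sketch only; argument objects are placeholders.
    copula_fun = fast_gaussian_copula_array_factory(my_marginal_model, my_uniformizer, top_k=100)
    parameters = copula_fun(loaders, lr=0.1, epochs=40)
    parameters["covariance"]   # FastCovarianceStructure, or a dict of them keyed by group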
@@ -65,7 +87,10 @@ def gaussian_copula_factory(copula_array_fun: Callable,
     return copula_fun
 
 
+
+
 def copula_covariance(parameters: dict, loaders: dict[str, DataLoader], uniformizer: Callable):
+
     first_loader = next(iter(loaders.values()))
     D = next(iter(first_loader))[1].shape[1]  # dimension of y
     groups = first_loader.dataset.groups  # a list of strings of group names
@@ -98,7 +123,183 @@ def copula_covariance(parameters: dict, loaders: dict[str, DataLoader], uniformi
 
     if len(groups) == 1:
         return list(result.values())[0]
-    return result
+    return result
+
+
+def fast_copula_covariance(parameters: dict, loaders: dict[str, DataLoader], uniformizer: Callable, top_k: int):
+    """
+    Compute an efficient approximation of copula covariance by modeling only the top-k most
+    prevalent genes with full covariance and approximating the rest with diagonal covariance.
+
+    Parameters
+    ----------
+    parameters : dict
+        Model parameters dictionary
+    loaders : dict[str, DataLoader]
+        Dictionary of data loaders
+    uniformizer : Callable
+        Function to convert to uniform distribution
+    top_k : int
+        Number of top genes to model with full covariance
+
+    Returns
+    -------
+    dict or FastCovarianceStructure
+        - If single group: FastCovarianceStructure containing:
+            * top_k_cov: (top_k, top_k) full covariance matrix for top genes
+            * remaining_var: (remaining_genes,) diagonal variances for remaining genes
+            * top_k_indices: indices of top-k genes
+            * remaining_indices: indices of remaining genes
+            * gene_total_expression: total expression levels for gene selection
+        - If multiple groups: dict mapping group names to FastCovarianceStructure objects
+    """
+
+    first_loader = next(iter(loaders.values()))
+    D = next(iter(first_loader))[1].shape[1]  # dimension of y
+    groups = first_loader.dataset.groups  # a list of strings of group names
+
+    # Validate top_k parameter
+    if top_k <= 0:
+        raise ValueError("top_k must be a positive integer")
+    if top_k >= D:
+        # If top_k is larger than the total number of genes, fall back to the regular covariance
+        return copula_covariance(parameters, loaders, uniformizer)
+
+    # Step 1: Calculate total expression for each gene to determine prevalence
+    gene_total_expression = np.zeros(D)
+
+    keys = list(loaders.keys())
+    loaders_list = list(loaders.values())
+    num_keys = len(keys)
+
+    # Calculate total expression across all batches
+    for batches in zip(*loaders_list):
+        y_batch = batches[0][1].cpu().numpy()
+        gene_total_expression += y_batch.sum(axis=0)
+
+    # Step 2: Select top-k most prevalent genes
+    top_k_indices = np.argsort(gene_total_expression)[-top_k:]
+    remaining_indices = np.argsort(gene_total_expression)[:-top_k]
+
+    # Step 3: Compute statistics for both top-k and remaining genes
+    sums_top_k = {g: np.zeros(top_k) for g in groups}
+    second_moments_top_k = {g: np.zeros((top_k, top_k)) for g in groups}
+
+    sums_remaining = {g: np.zeros(len(remaining_indices)) for g in groups}
+    second_moments_remaining = {g: np.zeros(len(remaining_indices)) for g in groups}
+
+    Ng = {g: 0 for g in groups}
+
+    # Reset loaders for second pass
+    loaders_list = list(loaders.values())
+
+    for batches in zip(*loaders_list):
+        x_batch_dict = {
+            keys[i]: batches[i][0].cpu().numpy() for i in range(num_keys)
+        }
+        y_batch = batches[0][1].cpu().numpy()
+        memberships = batches[0][2]  # should be identical for all keys
+
+        u = uniformizer(parameters, x_batch_dict, y_batch)
+
+        for g in groups:
+            ix = np.where(np.array(memberships) == g)
+            if len(ix[0]) == 0:
+                continue
+
+            z = norm().ppf(u[ix])
+
+            # Process top-k genes with full covariance
+            z_top_k = z[:, top_k_indices]
+            second_moments_top_k[g] += z_top_k.T @ z_top_k
+            sums_top_k[g] += z_top_k.sum(axis=0)
+
+            # Process remaining genes with diagonal covariance only
+            z_remaining = z[:, remaining_indices]
+            second_moments_remaining[g] += (z_remaining ** 2).sum(axis=0)
+            sums_remaining[g] += z_remaining.sum(axis=0)
+
+            Ng[g] += len(ix[0])
+
+    # Step 4: Compute final covariance structures
+    result = {}
+    for g in groups:
+        if Ng[g] == 0:
+            continue
+
+        # Full covariance for top-k genes
+        mean_top_k = sums_top_k[g] / Ng[g]
+        cov_top_k = second_moments_top_k[g] / Ng[g] - np.outer(mean_top_k, mean_top_k)
+
+        # Diagonal variance for remaining genes
+        mean_remaining = sums_remaining[g] / Ng[g]
+        var_remaining = second_moments_remaining[g] / Ng[g] - mean_remaining ** 2
+
+        # Create FastCovarianceStructure
+        result[g] = FastCovarianceStructure(
+            top_k_cov=cov_top_k,
+            remaining_var=var_remaining,
+            top_k_indices=top_k_indices,
+            remaining_indices=remaining_indices,
+            gene_total_expression=gene_total_expression
+        )
+
+    if len(groups) == 1:
+        return list(result.values())[0]
+    return result
+
+
+class FastCovarianceStructure:
+    """
+    Data structure to efficiently store and access covariance information for fast copula sampling.
+
+    Attributes
+    ----------
+    top_k_cov : np.ndarray
+        Full covariance matrix for top-k most prevalent genes, shape (top_k, top_k)
+    remaining_var : np.ndarray
+        Diagonal variances for remaining genes, shape (remaining_genes,)
+    top_k_indices : np.ndarray
+        Indices of the top-k genes in the original gene ordering
+    remaining_indices : np.ndarray
+        Indices of the remaining genes in the original gene ordering
+    gene_total_expression : np.ndarray
+        Total expression levels used for gene selection, shape (total_genes,)
+    """
+
+    def __init__(self, top_k_cov, remaining_var, top_k_indices, remaining_indices, gene_total_expression):
+        self.top_k_cov = top_k_cov
+        self.remaining_var = remaining_var
+        self.top_k_indices = top_k_indices
+        self.remaining_indices = remaining_indices
+        self.gene_total_expression = gene_total_expression
+        self.top_k = len(top_k_indices)
+        self.total_genes = len(top_k_indices) + len(remaining_indices)
+
+    def __repr__(self):
+        return (f"FastCovarianceStructure(top_k={self.top_k}, "
+                f"remaining_genes={len(self.remaining_indices)}, "
+                f"total_genes={self.total_genes})")
+
+    def to_full_matrix(self):
+        """
+        Convert to full covariance matrix for compatibility/debugging.
+
+        Returns
+        -------
+        np.ndarray : Full covariance matrix with shape (total_genes, total_genes)
+        """
+        full_cov = np.zeros((self.total_genes, self.total_genes))
+
+        # Fill in top-k block
+        ix_top = np.ix_(self.top_k_indices, self.top_k_indices)
+        full_cov[ix_top] = self.top_k_cov
+
+        # Fill in diagonal for remaining genes
+        full_cov[self.remaining_indices, self.remaining_indices] = self.remaining_var
+
+        return full_cov
+
 
 
 ###############################################################################
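The memory saving is the point of this new structure: instead of a dense D x D matrix it stores a top_k x top_k block plus D - top_k diagonal entries. A small hand-built example (toy names and values, not from the package) shows how to_full_matrix(), as defined above, reassembles the dense form:

    import numpy as np

    # Toy setup: 4 genes total, top_k = 2; genes 1 and 3 are the "prevalent" ones.
    fcs = FastCovarianceStructure(
        top_k_cov=np.array([[1.0, 0.3], [0.3, 1.0]]),
        remaining_var=np.array([0.8, 0.5]),            # diagonal variances for genes 0 and 2
        top_k_indices=np.array([1, 3]),
        remaining_indices=np.array([0, 2]),
        gene_total_expression=np.array([5.0, 90.0, 2.0, 70.0]),
    )
    full = fcs.to_full_matrix()   # 4 x 4: dense block on rows/cols (1, 3), diagonal entries at (0, 0) and (2, 2)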
@@ -128,18 +329,32 @@ def clip(u: np.array, min: float = 1e-5, max: float = 1 - 1e-5) -> np.array:
 
 
 def format_copula_parameters(parameters: dict, var_names: list):
+    '''
+    Format the copula parameters into a dictionary of covariance matrices in pandas DataFrame format.
+    If the covariance is a FastCovarianceStructure, or a dictionary of FastCovarianceStructure
+    objects, it is returned as is; otherwise the covariance matrices are converted to pandas DataFrames.
+    '''
     covariance = parameters["covariance"]
-    if type(covariance) is not dict:
+
+    # Handle FastCovarianceStructure - keep it as is since it has efficient methods
+    if isinstance(covariance, FastCovarianceStructure):
+        return covariance
+    elif isinstance(covariance, dict) and any(isinstance(v, FastCovarianceStructure) for v in covariance.values()):
+        # If it's a dict containing FastCovarianceStructure objects, keep as is
+        return covariance
+    elif type(covariance) is not dict:
         covariance = pd.DataFrame(
             parameters["covariance"], columns=list(var_names), index=list(var_names)
         )
     else:
         for group in covariance.keys():
-            covariance[group] = pd.DataFrame(
-                parameters["covariance"][group],
-                columns=list(var_names),
-                index=list(var_names),
-            )
+            if not isinstance(covariance[group], FastCovarianceStructure):
+                covariance[group] = pd.DataFrame(
+                    parameters["covariance"][group],
+                    columns=list(var_names),
+                    index=list(var_names),
+                )
     return covariance
 
 
@@ -127,3 +127,27 @@ negbin_copula = gcf.gaussian_copula_factory(
     negbin_copula_array, format_negbin_parameters_with_loaders,
     param_name=['mean', 'dispersion']
 )
+
+###############################################################################
+## Fast copula versions for negative binomial regression
+###############################################################################
+
+def fast_negbin_copula_array_factory(top_k: int):
+    """
+    top_k: int
+        Number of top genes to model with full covariance
+    """
+    return gcf.fast_gaussian_copula_array_factory(
+        negbin_regression_array, negbin_uniformizer, top_k
+    )
+
+def fast_negbin_copula_factory(top_k: int):
+    """
+    top_k: int
+        Number of top genes to model with full covariance
+    """
+    fast_copula_array = fast_negbin_copula_array_factory(top_k)
+    return gcf.gaussian_copula_factory(
+        fast_copula_array, format_negbin_parameters_with_loaders,
+        param_name=['mean', 'dispersion']
+    )
@@ -98,3 +98,27 @@ poisson_copula_array = gcf.gaussian_copula_array_factory(
 poisson_copula = gcf.gaussian_copula_factory(
     poisson_copula_array, format_poisson_parameters_with_loaders, ['mean']
 )
+
+###############################################################################
+## Fast copula versions for poisson regression
+###############################################################################
+
+def fast_poisson_copula_array_factory(top_k: int):
+    """
+    top_k: int
+        Number of top genes to model with full covariance
+    """
+    return gcf.fast_gaussian_copula_array_factory(
+        poisson_regression_array, poisson_uniformizer, top_k
+    )
+
+def fast_poisson_copula_factory(top_k: int):
+    """
+    top_k: int
+        Number of top genes to model with full covariance
+    """
+    fast_copula_array = fast_poisson_copula_array_factory(top_k)
+    return gcf.gaussian_copula_factory(
+        fast_copula_array, format_poisson_parameters_with_loaders,
+        param_name=['mean']
+    )
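The negative binomial and Poisson additions mirror each other: each fast_*_copula_factory binds the distribution-specific regression array and uniformizer into fast_gaussian_copula_array_factory and then reuses the existing gaussian_copula_factory wrapper. A hedged sketch of how a caller might construct them (the exact import path and the signature of the fitted copula function are not shown in this diff):

    # Assuming the re-exports added to __init__.py above are available to the caller:
    negbin_copula_fast = fast_negbin_copula_factory(top_k=500)
    poisson_copula_fast = fast_poisson_copula_factory(top_k=500)
    # Both wrap gcf.gaussian_copula_factory, so they should behave as drop-in alternatives to
    # negbin_copula / poisson_copula, trading the exact covariance for the top-k approximation.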
@@ -1,6 +1,6 @@
 from .loader import obs_loader
 from .scd3 import SCD3Simulator
-from .standard_covariance import StandardCovariance
+from .standard_copula import StandardCopula
 from anndata import AnnData
 from typing import Dict, Optional, List
 import numpy as np
@@ -10,7 +10,7 @@ class CompositeCopula(SCD3Simulator):
     def __init__(self, marginals: List,
                  copula_formula: Optional[str] = None) -> None:
         self.marginals = marginals
-        self.copula = StandardCovariance(copula_formula)
+        self.copula = StandardCopula(copula_formula)
         self.template = None
         self.parameters = None
         self.merged_formula = None
@@ -2,13 +2,16 @@ from typing import Dict, Callable, Tuple
 import torch
 from anndata import AnnData
 from .loader import adata_loader
-
-class Copula:
+from abc import ABC, abstractmethod
+import numpy as np
+import pandas as pd
+from typing import Optional, Union
+class Copula(ABC):
     def __init__(self, formula: str, **kwargs):
         self.formula = formula
         self.loader = None
         self.n_outcomes = None
-        self.parameters = None
+        self.parameters = None  # should be a dictionary of CovarianceStructure objects
 
     def setup_data(self, adata: AnnData, marginal_formula: Dict[str, str], batch_size: int = 1024, **kwargs):
         self.adata = adata
@@ -16,18 +19,187 @@ class Copula:
         self.loader = adata_loader(adata, self.formula, batch_size=batch_size, **kwargs)
         X_batch, _ = next(iter(self.loader))
         self.n_outcomes = X_batch.shape[1]
-
+
+    def decorrelate(self, row_pattern: str, col_pattern: str, group: Union[str, list, None] = None):
+        """Decorrelate the covariance matrix for the given row and column patterns.
+
+        Args:
+            row_pattern (str): The regex pattern for the row names to match.
+            col_pattern (str): The regex pattern for the column names to match.
+            group (Union[str, list, None]): The group or groups to apply the transformation to. If None, the transformation is applied to all groups.
+        """
+        if group is None:
+            for g in self.groups:
+                self.parameters[g].decorrelate(row_pattern, col_pattern)
+        elif isinstance(group, str):
+            self.parameters[group].decorrelate(row_pattern, col_pattern)
+        else:
+            for g in group:
+                self.parameters[g].decorrelate(row_pattern, col_pattern)
+
+    def correlate(self, factor: float, row_pattern: str, col_pattern: str, group: Union[str, list, None] = None):
+        """Multiply selected off-diagonal entries by factor.
+
+        Args:
+            row_pattern (str): The regex pattern for the row names to match.
+            col_pattern (str): The regex pattern for the column names to match.
+            factor (float): The factor to multiply the off-diagonal entries by.
+            group (Union[str, list, None]): The group or groups to apply the transformation to. If None, the transformation is applied to all groups.
+        """
+        if group is None:
+            for g in self.groups:
+                self.parameters[g].correlate(row_pattern, col_pattern, factor)
+        elif isinstance(group, str):
+            self.parameters[group].correlate(row_pattern, col_pattern, factor)
+        else:
+            for g in group:
+                self.parameters[g].correlate(row_pattern, col_pattern, factor)
+
+    @abstractmethod
     def fit(self, uniformizer: Callable, **kwargs):
         raise NotImplementedError
 
+    @abstractmethod
     def pseudo_obs(self, x_dict: Dict):
         raise NotImplementedError
 
+    @abstractmethod
     def likelihood(self, uniformizer: Callable, batch: Tuple[torch.Tensor, Dict[str, torch.Tensor]]):
         raise NotImplementedError
 
+    @abstractmethod
     def num_params(self, **kwargs):
         raise NotImplementedError
 
-    def format_parameters(self):
-        raise NotImplementedError
+    # @abstractmethod
+    # def format_parameters(self):
+    #     raise NotImplementedError
+
+class CovarianceStructure:
+    """
+    Efficient storage for covariance matrices in copula-based gene expression modeling.
+
+    This class provides memory-efficient storage for covariance information by storing
+    either a full covariance matrix or a block matrix with diagonal variances for
+    remaining genes. This enables fast copula estimation and sampling for large
+    gene expression datasets.
+
+    Attributes
+    ----------
+    cov : pd.DataFrame
+        Covariance matrix for modeled genes with gene names as index/columns
+    modeled_indices : np.ndarray
+        Indices of modeled genes in original ordering
+    remaining_var : pd.Series or None
+        Diagonal variances for remaining genes, None if full matrix stored
+    remaining_indices : np.ndarray or None
+        Indices of remaining genes in original ordering
+    num_modeled_genes : int
+        Number of modeled genes
+    num_remaining_genes : int
+        Number of remaining genes (0 if full matrix stored)
+    total_genes : int
+        Total number of genes
+    """
+
+    def __init__(self, cov: np.ndarray,
+                 modeled_names: pd.Index,
+                 modeled_indices: Optional[np.ndarray] = None,
+                 remaining_var: Optional[np.ndarray] = None,
+                 remaining_indices: Optional[np.ndarray] = None,
+                 remaining_names: Optional[pd.Index] = None):
+        """Initialize a CovarianceStructure object.
+
+        Args:
+            cov (np.ndarray): Covariance matrix for modeled genes, shape (n_modeled_genes, n_modeled_genes)
+            modeled_names (pd.Index): Gene names for the modeled genes
+            modeled_indices (Optional[np.ndarray], optional): Indices of modeled genes in original ordering. Defaults to sequential indices.
+            remaining_var (Optional[np.ndarray], optional): Diagonal variances for remaining genes, shape (n_remaining_genes,)
+            remaining_indices (Optional[np.ndarray], optional): Indices of remaining genes in original ordering
+            remaining_names (Optional[pd.Index], optional): Gene names for remaining genes
+        """
+        self.cov = pd.DataFrame(cov, index=modeled_names, columns=modeled_names)
+
+        if modeled_indices is not None:
+            self.modeled_indices = modeled_indices
+        else:
+            self.modeled_indices = np.arange(len(modeled_names))
+
+        if remaining_var is not None:
+            self.remaining_var = pd.Series(remaining_var, index=remaining_names)
+        else:
+            self.remaining_var = None
+
+        self.remaining_indices = remaining_indices
+        self.num_modeled_genes = len(modeled_names)
+        self.num_remaining_genes = len(remaining_indices) if remaining_indices is not None else 0
+        self.total_genes = self.num_modeled_genes + self.num_remaining_genes
+
+    def __repr__(self):
+        if self.remaining_var is None:
+            return self.cov.__repr__()
+        else:
+            return f"CovarianceStructure(modeled_genes={self.num_modeled_genes}, \
+                total_genes={self.total_genes})"
+
+    def _repr_html_(self):
+        """Jupyter Notebook display"""
+        if self.remaining_var is None:
+            return self.cov._repr_html_()
+        else:
+            html = f"<b>CovarianceStructure:</b> {self.num_modeled_genes} modeled genes, {self.total_genes} total<br>"
+            html += "<h4>Modeled Covariance Matrix</h4>" + self.cov._repr_html_()
+            html += "<h4>Remaining Gene Variances</h4>" + self.remaining_var.to_frame("variance").T._repr_html_()
+            return html
+
+    def decorrelate(self, row_pattern: str, col_pattern: str):
+        """Decorrelate the covariance matrix for the given row and column patterns.
+        """
+        from .transform import data_frame_mask
+        m1 = data_frame_mask(self.cov, ".", col_pattern)
+        m2 = data_frame_mask(self.cov, row_pattern, ".")
+        mask = (m1 | m2)
+        np.fill_diagonal(mask, False)
+        self.cov.values[mask] = 0
+
+    def correlate(self, row_pattern: str, col_pattern: str, factor: float):
+        """Multiply selected off-diagonal entries by factor.
+
+        Args:
+            row_pattern (str): The regex pattern for the row names to match.
+            col_pattern (str): The regex pattern for the column names to match.
+            factor (float): The factor to multiply the off-diagonal entries by.
+        """
+        from .transform import data_frame_mask
+        m1 = data_frame_mask(self.cov, ".", col_pattern)
+        m2 = data_frame_mask(self.cov, row_pattern, ".")
+        mask = (m1 | m2)
+        np.fill_diagonal(mask, False)
+        self.cov.values[mask] = self.cov.values[mask] * factor
+
+    @property
+    def shape(self):
+        return (self.total_genes, self.total_genes)
+
+    def to_full_matrix(self):
+        """
+        Convert to full covariance matrix for compatibility/debugging.
+
+        Returns
+        -------
+        np.ndarray : Full covariance matrix with shape (total_genes, total_genes)
+        """
+        if self.remaining_var is None:
+            return self.cov.values
+        else:
+            full_cov = np.zeros((self.total_genes, self.total_genes))
+
+            # Fill in the modeled (top-k) block
+            ix_modeled = np.ix_(self.modeled_indices, self.modeled_indices)
+            full_cov[ix_modeled] = self.cov.values
+
+            # Fill in the diagonal for remaining genes
+            full_cov[self.remaining_indices, self.remaining_indices] = self.remaining_var.values
+
+            return full_cov
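CovarianceStructure generalizes the same idea on the simulator side: it can hold either a plain dense matrix (remaining_var is None) or the block-plus-diagonal form, and decorrelate/correlate edit entries selected by regex patterns on gene names. A toy construction (invented names and values, not from the package) illustrating the block form, the shape property, and to_full_matrix():

    import numpy as np
    import pandas as pd

    cs = CovarianceStructure(
        cov=np.array([[1.0, 0.4], [0.4, 1.0]]),        # dense block for the two modeled genes
        modeled_names=pd.Index(["geneA", "geneB"]),
        modeled_indices=np.array([0, 2]),
        remaining_var=np.array([0.7, 0.9]),            # diagonal variances for the other two genes
        remaining_indices=np.array([1, 3]),
        remaining_names=pd.Index(["geneC", "geneD"]),
    )
    cs.shape              # (4, 4)
    cs.to_full_matrix()   # block at rows/cols (0, 2); 0.7 and 0.9 on the diagonal at (1, 1) and (3, 3)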
@@ -6,9 +6,10 @@ import pandas as pd
 import pytorch_lightning as pl
 import torch
 import torch.nn as nn
+from abc import ABC, abstractmethod
 
 
-class Marginal:
+class Marginal(ABC):
     def __init__(self, formula: Union[Dict, str]):
         self.formula = formula
         self.feature_dims = None
@@ -37,23 +38,6 @@ class Marginal:
         trainer.fit(self.predict, train_dataloaders=self.loader)
         self.parameters = self.format_parameters()
 
-    def setup_optimizer(self, **kwargs):
-        raise NotImplementedError
-
-    def likelihood(self, batch: Tuple[torch.Tensor, Dict[str, torch.Tensor]]):
-        """Compute the (negative) log-likelihood or loss for a batch.
-        """
-        raise NotImplementedError
-
-    def invert(self, u: torch.Tensor, x: Dict[str, torch.Tensor]):
-        """Invert pseudoobservations."""
-        raise NotImplementedError
-
-    def uniformize(self, y: torch.Tensor, x: Dict[str, torch.Tensor]):
-        """Uniformize using learned CDF.
-        """
-        raise NotImplementedError
-
     def format_parameters(self):
         """Convert fitted coefficient tensors into pandas DataFrames.
 
@@ -79,6 +63,27 @@ class Marginal:
             return 0
         return sum(p.numel() for p in self.predict.parameters() if p.requires_grad)
 
+    @abstractmethod
+    def setup_optimizer(self, **kwargs):
+        raise NotImplementedError
+
+    @abstractmethod
+    def likelihood(self, batch: Tuple[torch.Tensor, Dict[str, torch.Tensor]]):
+        """Compute the (negative) log-likelihood or loss for a batch.
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def invert(self, u: torch.Tensor, x: Dict[str, torch.Tensor]):
+        """Invert pseudoobservations."""
+        raise NotImplementedError
+
+    @abstractmethod
+    def uniformize(self, y: torch.Tensor, x: Dict[str, torch.Tensor]):
+        """Uniformize using learned CDF.
+        """
+        raise NotImplementedError
+
 
 class GLMPredictor(pl.LightningModule):
     """GLM-style predictor with arbitrary named parameters.
@@ -136,5 +141,5 @@ class GLMPredictor(pl.LightningModule):
         return loss
 
     def configure_optimizers(self, **kwargs):
-        optimizer_kwargs = _filter_kwargs(kwargs, DEFAULT_ALLOWED_KWARGS['optimizer'])
+        optimizer_kwargs = _filter_kwargs(self.optimizer_kwargs, DEFAULT_ALLOWED_KWARGS['optimizer'])
         return self.optimizer_class(self.parameters(), **optimizer_kwargs)
@@ -30,7 +30,7 @@ class NegBin(Marginal):
         )
 
     def likelihood(self, batch):
-        """Compute the negative log-likelihood"""
+        """Compute the log-likelihood"""
        y, x = batch
        params = self.predict(x)
        mu = params.get('mean')
@@ -6,6 +6,7 @@ from anndata import AnnData
 from tqdm import tqdm
 import torch
 import numpy as np
+from abc import ABC, abstractmethod
 
 class SCD3Simulator(Simulator):
     """Simulation wrapper"""