scdesigner 0.0.3__py3-none-any.whl → 0.0.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,383 @@
1
+ from .copula import Copula
2
+ from .formula import standardize_formula
3
+ from .kwargs import DEFAULT_ALLOWED_KWARGS, _filter_kwargs
4
+ from anndata import AnnData
5
+ from scipy.stats import norm, multivariate_normal
6
+ from tqdm import tqdm
7
+ from typing import Dict, Union, Callable, Tuple
8
+ import numpy as np
9
+ import torch
10
+ from .copula import CovarianceStructure
11
+ import warnings
12
+
13
class StandardCopula(Copula):
    """Standard Gaussian copula with one covariance structure per group.

    After ``fit``, ``self.parameters`` is a dict mapping each group that had
    at least one observation to a ``CovarianceStructure``.
    """

    def __init__(self, formula: str = "~ 1"):
        """Initialize the StandardCopula model.

        Args:
            formula (str, optional): Copula formula. It is passed through
                ``standardize_formula`` with ``group`` as the only allowed
                key. Defaults to "~ 1".
        """
        formula = standardize_formula(formula, allowed_keys=['group'])
        super().__init__(formula)
        self.groups = None

    def setup_data(self, adata: AnnData, marginal_formula: Dict[str, str], **kwargs):
        """Set up the data loader and the group indexing variables.

        Sets ``self.groups`` (group names taken from the loader's dataset),
        ``self.n_groups`` and ``self.group_col`` (group name -> column index
        in the membership matrix).

        Args:
            adata (AnnData): The AnnData object containing the data.
            marginal_formula (Dict[str, str]): The formula for the marginal model.
            **kwargs: Only the keys listed in ``DEFAULT_ALLOWED_KWARGS['data']``
                are forwarded to the base-class ``setup_data``.

        Raises:
            ValueError: If the "group" matrix of the first batch contains
                values other than 0 and 1.
        """
        data_kwargs = _filter_kwargs(kwargs, DEFAULT_ALLOWED_KWARGS['data'])
        super().setup_data(adata, marginal_formula, **data_kwargs)
        _, obs_batch = next(iter(self.loader))
        obs_batch_group = obs_batch.get("group")

        # fill in group indexing variables
        self.groups = self.loader.dataset.predictor_names["group"]
        self.n_groups = len(self.groups)
        self.group_col = {g: i for i, g in enumerate(self.groups)}

        # Only the first batch is inspected; it must be a 0/1 membership matrix.
        if obs_batch_group is not None:
            unique_vals = torch.unique(obs_batch_group)
            is_binary = torch.all((unique_vals == 0) | (unique_vals == 1)).item()
            if not is_binary:
                raise ValueError("Only categorical groups are currently supported in copula covariance estimation.")

    def fit(self, uniformizer: Callable, **kwargs):
        """Fit the copula covariance model.

        Args:
            uniformizer (Callable): Called as ``uniformizer(y, x_dict)`` to map
                a batch to the uniform scale.
            **kwargs: Additional arguments.
                top_k (int, optional): Estimate a full covariance only for the
                    ``top_k`` genes with the largest total expression in
                    ``self.adata.X``; the remaining genes only get a variance.
                    If None, a full covariance is estimated for all genes.

        Returns:
            None: Stores the fitted parameters in ``self.parameters`` as a
            dict of CovarianceStructure objects.

        Raises:
            ValueError: If top_k is not a positive integer or exceeds n_outcomes.
        """
        top_k = self._validate_parameters(**kwargs)
        if top_k is None:
            self.parameters = self._compute_full_covariance(uniformizer)
            return

        # Rank genes by total expression; argsort is ascending, so the top-k
        # genes are the last k entries.
        gene_total_expression = np.array(self.adata.X.sum(axis=0)).flatten()
        sorted_indices = np.argsort(gene_total_expression)
        top_k_indices = sorted_indices[-top_k:]
        remaining_indices = sorted_indices[:-top_k]
        self.parameters = self._compute_block_covariance(
            uniformizer, top_k_indices, remaining_indices, top_k
        )

    def _group_indices(self, memberships: np.ndarray) -> Dict[Union[str, int], np.ndarray]:
        """Map each group name to the row indices whose membership entry is 1.

        Args:
            memberships (np.ndarray): 2-D membership matrix whose column
                ``self.group_col[g]`` flags membership of group ``g``.

        Returns:
            dict: ``{group: np.ndarray of row indices}`` for every group in
            ``self.groups``.
        """
        # The comparison must be applied to the selected column, not to the
        # column index itself.
        return {
            g: np.where(memberships[:, self.group_col[g]] == 1)[0]
            for g in self.groups
        }

    def pseudo_obs(self, x_dict: Dict):
        """Sample copula pseudo-observations for the rows described by ``x_dict``.

        Args:
            x_dict (Dict): Covariates dict; its "group" entry is read as a
                tensor holding the 0/1 group membership matrix.

        Returns:
            np.ndarray: Array of shape (n_rows, n_outcomes). Rows belonging to
            a group without fitted parameters are left at zero.
        """
        group_data = x_dict.get("group")
        memberships = group_data.cpu().numpy()
        group_ix = self._group_indices(memberships)

        # initialize the result
        u = np.zeros((len(memberships), self.n_outcomes))

        # loop over groups and sample each part in turn
        for group, cov_struct in self.parameters.items():
            ix = group_ix[group]
            if cov_struct.remaining_var is not None:
                u[ix] = self._fast_normal_pseudo_obs(len(ix), cov_struct)
            else:
                u[ix] = self._normal_pseudo_obs(len(ix), cov_struct)
        return u

    def likelihood(self, uniformizer: Callable, batch: Tuple[torch.Tensor, Dict[str, torch.Tensor]]):
        """Compute the copula log-likelihood of each row in a batch.

        Args:
            uniformizer (Callable): Function to convert expression data to uniform distribution
            batch (Tuple[torch.Tensor, Dict[str, torch.Tensor]]): Data batch containing:
                - Y (torch.Tensor): Expression data of shape (n_cells, n_genes)
                - X_dict (Dict[str, torch.Tensor]): Covariates dict; the
                  "group" entry holds the 0/1 group membership matrix.

        Returns:
            np.ndarray: Log-likelihood for each cell, shape (n_cells,). Rows of
            groups without fitted parameters keep a value of zero.
        """
        # uniformize the observations and move them to the normal scale
        y, x_dict = batch
        u = uniformizer(y, x_dict)
        z = norm.ppf(u)

        # same group manipulation as for pseudo_obs
        parameters = self.parameters
        if not isinstance(parameters, dict):
            parameters = {self.groups[0]: parameters}

        group_data = x_dict.get("group")
        memberships = group_data.cpu().numpy()
        group_ix = self._group_indices(memberships)

        ll = np.zeros(len(z))
        for group, cov_struct in parameters.items():
            ix = group_ix[group]
            if len(ix) == 0:
                continue

            z_group = z[ix]
            ll_group = multivariate_normal.logpdf(
                z_group[:, cov_struct.modeled_indices],
                np.zeros(cov_struct.num_modeled_genes),
                cov_struct.cov.values,
            )
            if cov_struct.num_remaining_genes > 0:
                # Remaining genes are modeled as independent normals, so their
                # per-gene log-densities are summed to one value per cell.
                ll_remaining = norm.logpdf(
                    z_group[:, cov_struct.remaining_indices],
                    loc=0,
                    scale=np.sqrt(cov_struct.remaining_var.values),
                )
                ll_group = ll_group + ll_remaining.sum(axis=1)
            ll[ix] = ll_group
        return ll

    def num_params(self, **kwargs):
        """Count the off-diagonal covariance entries over all fitted groups.

        Returns:
            The sum over fitted groups of ``k * (k - 1) / 2`` where ``k`` is
            the number of modeled genes of that group.
        """
        # Iterate over the fitted structures rather than self.groups: groups
        # without observations are skipped during fitting and have no entry.
        return sum(
            (cov_struct.num_modeled_genes * (cov_struct.num_modeled_genes - 1)) / 2
            for cov_struct in self.parameters.values()
        )

    def _validate_parameters(self, **kwargs):
        """Validate and return the optional ``top_k`` keyword argument.

        Args:
            **kwargs: Keyword arguments; only ``top_k`` is inspected.

        Returns:
            The validated ``top_k`` as an ``int``, or None when it was not given.

        Raises:
            ValueError: If top_k is not an integer, is not positive, or
                exceeds ``self.n_outcomes``.
        """
        top_k = kwargs.get("top_k", None)
        if top_k is None:
            return None
        # NumPy integers are accepted as well as built-in ints.
        if not isinstance(top_k, (int, np.integer)):
            raise ValueError("top_k must be an integer")
        if top_k <= 0:
            raise ValueError("top_k must be positive")
        if top_k > self.n_outcomes:
            raise ValueError(f"top_k ({top_k}) cannot exceed number of outcomes ({self.n_outcomes})")
        return int(top_k)

    def _iter_group_scores(self, uniformizer: Callable, desc: str):
        """Yield ``(group, z_g, n_g)`` for every non-empty group of every batch.

        ``z_g`` holds the normal scores ``norm.ppf(uniformizer(y, x_dict))`` of
        the rows that belong to ``group`` and ``n_g`` is the number of such rows.

        Args:
            uniformizer (Callable): Function to convert to uniform distribution
            desc (str): Progress-bar description.
        """
        for y, x_dict in tqdm(self.loader, desc=desc):
            group_data = x_dict.get("group")
            memberships = group_data.cpu().numpy()
            u = uniformizer(y, x_dict)
            z = norm.ppf(u)

            for g in self.groups:
                mask = memberships[:, self.group_col[g]] == 1
                if not np.any(mask):
                    continue
                yield g, z[mask], int(mask.sum())

    def _accumulate_top_k_stats(self, uniformizer: Callable, top_k_idx, rem_idx, top_k) \
            -> Tuple[Dict[Union[str, int], np.ndarray],
                     Dict[Union[str, int], np.ndarray],
                     Dict[Union[str, int], np.ndarray],
                     Dict[Union[str, int], np.ndarray],
                     Dict[Union[str, int], int]]:
        """Accumulate sufficient statistics for top-k covariance estimation.

        Args:
            uniformizer (Callable): Function to convert to uniform distribution
            top_k_idx (np.ndarray): Indices of the top-k genes
            rem_idx (np.ndarray): Indices of the remaining genes
            top_k (int): Number of top-k genes

        Returns:
            top_k_sums (dict): Sums of the top-k genes for each group
            top_k_second_moments (dict): Second moments of the top-k genes for each group
            rem_sums (dict): Sums of the remaining genes for each group
            rem_second_moments (dict): Second moments of the remaining genes for each group
            Ng (dict): Number of observations for each group
        """
        n_rem = self.n_outcomes - top_k
        top_k_sums = {g: np.zeros(top_k) for g in self.groups}
        top_k_second_moments = {g: np.zeros((top_k, top_k)) for g in self.groups}
        rem_sums = {g: np.zeros(n_rem) for g in self.groups}
        rem_second_moments = {g: np.zeros(n_rem) for g in self.groups}
        Ng = {g: 0 for g in self.groups}

        for g, z_g, n_g in self._iter_group_scores(uniformizer, "Estimating top-k copula covariance"):
            top_k_z, rem_z = z_g[:, top_k_idx], z_g[:, rem_idx]

            top_k_sums[g] += top_k_z.sum(axis=0)
            top_k_second_moments[g] += top_k_z.T @ top_k_z

            # Only per-gene (diagonal) second moments are kept for the rest.
            rem_sums[g] += rem_z.sum(axis=0)
            rem_second_moments[g] += (rem_z ** 2).sum(axis=0)

            Ng[g] += n_g

        return top_k_sums, top_k_second_moments, rem_sums, rem_second_moments, Ng

    def _accumulate_full_stats(self, uniformizer: Callable) \
            -> Tuple[Dict[Union[str, int], np.ndarray],
                     Dict[Union[str, int], np.ndarray],
                     Dict[Union[str, int], int]]:
        """Accumulate sufficient statistics for full covariance estimation.

        Args:
            uniformizer (Callable): Function to convert to uniform distribution

        Returns:
            sums (dict): Sums of the genes for each group
            second_moments (dict): Second moments of the genes for each group
            Ng (dict): Number of observations for each group
        """
        sums = {g: np.zeros(self.n_outcomes) for g in self.groups}
        second_moments = {g: np.zeros((self.n_outcomes, self.n_outcomes)) for g in self.groups}
        Ng = {g: 0 for g in self.groups}

        for g, z_g, n_g in self._iter_group_scores(uniformizer, "Estimating copula covariance"):
            second_moments[g] += z_g.T @ z_g
            sums[g] += z_g.sum(axis=0)
            Ng[g] += n_g

        return sums, second_moments, Ng

    def _compute_block_covariance(self, uniformizer: Callable,
                                  top_k_idx: np.ndarray, rem_idx: np.ndarray, top_k: int) \
            -> Dict[Union[str, int], CovarianceStructure]:
        """Compute a top-k covariance matrix plus remaining-gene variances per group.

        Args:
            uniformizer (Callable): Function to convert to uniform distribution
            top_k_idx (np.ndarray): Indices of the top-k genes
            rem_idx (np.ndarray): Indices of the remaining genes
            top_k (int): Number of top-k genes

        Returns:
            covariance (dict): CovarianceStructure for each group that has at
            least one observation; empty groups are skipped with a warning.
        """
        top_k_sums, top_k_second_moments, remaining_sums, remaining_second_moments, Ng \
            = self._accumulate_top_k_stats(uniformizer, top_k_idx, rem_idx, top_k)

        # Gene names do not depend on the group, so look them up once.
        top_k_names = self.adata.var_names[top_k_idx]
        remaining_names = self.adata.var_names[rem_idx]

        covariance = {}
        for g in self.groups:
            if Ng[g] == 0:
                warnings.warn(f"Group {g} has no observations, skipping")
                continue
            # Moment estimates: E[zz^T] - E[z]E[z]^T, dividing by Ng.
            mean_top_k = top_k_sums[g] / Ng[g]
            cov_top_k = top_k_second_moments[g] / Ng[g] - np.outer(mean_top_k, mean_top_k)
            mean_remaining = remaining_sums[g] / Ng[g]
            var_remaining = remaining_second_moments[g] / Ng[g] - mean_remaining ** 2
            covariance[g] = CovarianceStructure(
                cov=cov_top_k,
                modeled_names=top_k_names,
                modeled_indices=top_k_idx,
                remaining_var=var_remaining,
                remaining_indices=rem_idx,
                remaining_names=remaining_names
            )
        return covariance

    def _compute_full_covariance(self, uniformizer: Callable) -> Dict[Union[str, int], CovarianceStructure]:
        """Compute the covariance matrix over all genes for each group.

        Args:
            uniformizer (Callable): Function to convert to uniform distribution

        Returns:
            covariance (dict): CovarianceStructure for each group that has at
            least one observation; empty groups are skipped with a warning.
        """
        sums, second_moments, Ng = self._accumulate_full_stats(uniformizer)
        covariance = {}
        for g in self.groups:
            if Ng[g] == 0:
                warnings.warn(f"Group {g} has no observations, skipping")
                continue
            mean = sums[g] / Ng[g]
            cov = second_moments[g] / Ng[g] - np.outer(mean, mean)
            covariance[g] = CovarianceStructure(
                cov=cov,
                modeled_names=self.adata.var_names,
                modeled_indices=np.arange(self.n_outcomes),
                remaining_var=None,
                remaining_indices=None,
                remaining_names=None
            )
        return covariance

    def _fast_normal_pseudo_obs(self, n_samples: int, cov_struct: CovarianceStructure) -> np.ndarray:
        """Sample pseudo-observations from a block covariance structure.

        The modeled genes are drawn jointly from a multivariate normal; the
        remaining genes are drawn as independent normals. Each draw is mapped
        through the normal CDF with the matching per-gene standard deviation.

        Args:
            n_samples (int): Number of samples to generate
            cov_struct (CovarianceStructure): The covariance structure; its
                ``remaining_var`` must not be None.

        Returns:
            np.ndarray: Pseudo-observations with shape (n_samples, total_genes)
        """
        u = np.zeros((n_samples, cov_struct.total_genes))
        cov = cov_struct.cov.values
        remaining_std = cov_struct.remaining_var.values ** 0.5

        z_modeled = np.random.multivariate_normal(
            mean=np.zeros(cov_struct.num_modeled_genes),
            cov=cov,
            size=n_samples
        )
        z_remaining = np.random.normal(
            loc=0,
            scale=remaining_std,
            size=(n_samples, cov_struct.num_remaining_genes)
        )

        u[:, cov_struct.modeled_indices] = norm(0, np.diag(cov) ** 0.5).cdf(z_modeled)
        u[:, cov_struct.remaining_indices] = norm(0, remaining_std).cdf(z_remaining)
        return u

    def _normal_pseudo_obs(self, n_samples: int, cov_struct: CovarianceStructure) -> np.ndarray:
        """Sample pseudo-observations from a full covariance structure.

        Args:
            n_samples (int): Number of samples to generate
            cov_struct (CovarianceStructure): The covariance structure

        Returns:
            np.ndarray: Pseudo-observations with shape (n_samples, total_genes)
        """
        cov = cov_struct.cov.values
        z = np.random.multivariate_normal(
            mean=np.zeros(cov_struct.total_genes),
            cov=cov,
            size=n_samples
        )
        # Map each gene through the normal CDF with its own standard deviation.
        return norm(0, np.diag(cov) ** 0.5).cdf(z)
@@ -1,8 +1,10 @@
1
1
  from typing import Union, Sequence
2
2
  import numpy as np
3
+ import pandas as pd
3
4
  import re
4
5
  import torch
5
6
  import copy
7
+ from .copula import CovarianceStructure
6
8
 
7
9
 
8
10
  def nullify(sim, row_pattern: str, col_pattern: str, param: str):
@@ -31,41 +33,33 @@ def amplify(sim, factor: float, row_pattern: str, col_pattern: str, param: str):
31
33
 
32
34
 
33
35
  def decorrelate(sim, row_pattern: str, col_pattern: str, group: Union[str, None] = None):
34
- """Zero out selected off-diagonal entries of a covariance."""
35
- sim = copy.deepcopy(sim)
36
- def _apply_to_df(df):
37
- m1 = data_frame_mask(df, ".", col_pattern)
38
- m2 = data_frame_mask(df, row_pattern, ".")
39
- mask = (m1 | m2)
40
- np.fill_diagonal(mask, False)
41
- df.values[mask] = 0
42
-
43
- cov = sim.parameters["copula"]
44
- _apply_to_groups(cov, group, _apply_to_df)
45
- return sim
36
+ """Zero out selected off-diagonal entries of a covariance.
37
+ """
38
+ decorr_sim = copy.deepcopy(sim)
39
+ decorr_sim.copula.decorrelate(row_pattern, col_pattern, group)
40
+ return decorr_sim
46
41
 
47
42
 
48
43
  def correlate(sim, factor: float, row_pattern: str, col_pattern: str, group: Union[str, None] = None):
49
44
  """Multiply selected off-diagonal entries by factor."""
50
- sim = copy.deepcopy(sim)
51
- def _apply_to_df(df):
52
- m1 = data_frame_mask(df, ".", col_pattern)
53
- m2 = data_frame_mask(df, row_pattern, ".")
54
- mask = (m1 | m2)
55
- np.fill_diagonal(mask, False)
56
- df.values[mask] = df.values[mask] * factor
57
-
58
- cov = sim.parameters["copula"]
59
- _apply_to_groups(cov, group, _apply_to_df)
60
- return sim
45
+ corr_sim = copy.deepcopy(sim)
46
+ corr_sim.copula.correlate(factor, row_pattern, col_pattern, group)
47
+ return corr_sim
61
48
 
62
49
 
63
- def replace_param(sim, path: Sequence[str], new_param):
50
+ def replace_param(sim, path: Sequence[str], new_param: Union[np.ndarray, pd.DataFrame, CovarianceStructure]):
64
51
  """Substitute a new parameter for an old one.
65
52
 
66
53
  Use the path to the parameter starting from sim.parameters to identify the
67
54
  parameter to transform. Examples: ['marginal','mean'] or
68
55
  ['copula','group_name']
56
+
57
+ Args:
58
+ sim (Simulator): The simulator object.
59
+ path (Sequence[str]): The path to the parameter to transform.
60
+ new_param (np.ndarray): The new parameter to substitute.
61
+ For replacing a covariance structure, new_param could be a numpy array of shape (n_genes, n_genes)
62
+ or a CovarianceStructure object defined by the user.
69
63
  """
70
64
  sim = copy.deepcopy(sim)
71
65
  if path[0] == "marginal":
@@ -74,13 +68,16 @@ def replace_param(sim, path: Sequence[str], new_param):
74
68
  _update_marginal_param(sim, param, mat)
75
69
 
76
70
  if path[0] == "copula":
77
- key = path[1]
78
- cov = sim.parameters["copula"]
79
- if isinstance(cov, dict):
80
- cov[key] = new_param
71
+ if isinstance(new_param, np.ndarray):
72
+ sim.parameters["copula"][path[1]] = CovarianceStructure(new_param,
73
+ modeled_names=sim.adata.var_names)
74
+ elif isinstance(new_param, pd.DataFrame):
75
+ sim.parameters["copula"][path[1]] = CovarianceStructure(new_param.values,
76
+ modeled_names=new_param.index)
77
+ elif isinstance(new_param, CovarianceStructure):
78
+ sim.parameters["copula"][path[1]] = new_param
81
79
  else:
82
- sim.parameters["copula"] = new_param
83
-
80
+ raise ValueError(f"new_param must be a numpy array or a CovarianceStructure object, got {type(new_param)}")
84
81
  return sim
85
82
 
86
83
 
@@ -14,18 +14,80 @@ def glm_sample_factory(sample_array):
14
14
  return sampler
15
15
 
16
16
def gaussian_copula_pseudo_obs(N, G, sigma, groups):
    """Sample Gaussian-copula pseudo-observations group by group.

    Args:
        N: Number of rows of the result.
        G: Number of columns of the result (dimension of each covariance).
        sigma: Either a dict mapping each group to its covariance (a full
            covariance matrix or a FastCovarianceStructure), or a single such
            object that is shared by every group.
        groups: Dict mapping each group to the row indices that belong to it.

    Returns:
        np.ndarray: Array of shape (N, G); rows not listed in ``groups``
        stay at zero.
    """
    # Import here to avoid circular imports
    from ..estimators.gaussian_copula_factory import FastCovarianceStructure

    # If sigma is not a dict, then every group shares the same sigma.  Resolve
    # this once before the loop so that every group gets an entry, not just
    # the first one visited.
    if not isinstance(sigma, dict):
        sigma = {group: sigma for group in groups}

    u = np.zeros((N, G))

    # cycle across groups
    for group, ix in groups.items():
        group_sigma = sigma[group]

        # Handle FastCovarianceStructure
        if isinstance(group_sigma, FastCovarianceStructure):
            u[ix] = _fast_copula_pseudo_obs(len(ix), group_sigma)
        else:
            # Traditional full covariance matrix approach
            z = np.random.multivariate_normal(
                mean=np.zeros(G), cov=group_sigma, size=len(ix)
            )
            # Take the diagonal first: square-rooting the whole matrix would
            # also hit negative off-diagonal covariances.
            normal_distn = norm(0, np.sqrt(np.diag(group_sigma)))
            u[ix] = normal_distn.cdf(z)
    return u
23
42
 
24
- z = np.random.multivariate_normal(
25
- mean=np.zeros(G), cov=sigma[group], size=len(ix)
43
+
44
+ def _fast_copula_pseudo_obs(n_samples, fast_cov_struct):
45
+ """
46
+ Efficient pseudo-observation generation using FastCovarianceStructure.
47
+
48
+ This function separately samples:
49
+ 1. Top-k genes using full multivariate normal with their covariance matrix
50
+ 2. Remaining genes using independent normal with their individual variances
51
+
52
+ Parameters:
53
+ -----------
54
+ n_samples : int
55
+ Number of samples to generate for this group
56
+ fast_cov_struct : FastCovarianceStructure
57
+ Structure containing top-k covariance and remaining variances
58
+
59
+ Returns:
60
+ --------
61
+ np.ndarray : Pseudo-observations with shape (n_samples, total_genes)
62
+ """
63
+ u = np.zeros((n_samples, fast_cov_struct.total_genes))
64
+
65
+ # Sample top-k genes with full covariance
66
+ if fast_cov_struct.top_k > 0:
67
+ z_top_k = np.random.multivariate_normal(
68
+ mean=np.zeros(fast_cov_struct.top_k),
69
+ cov=fast_cov_struct.top_k_cov,
70
+ size=n_samples
71
+ )
72
+
73
+ # Convert to uniform via marginal CDFs
74
+ top_k_std = np.sqrt(np.diag(fast_cov_struct.top_k_cov))
75
+ normal_distn_top_k = norm(0, top_k_std)
76
+ u[:, fast_cov_struct.top_k_indices] = normal_distn_top_k.cdf(z_top_k)
77
+
78
+ # Sample remaining genes independently
79
+ if len(fast_cov_struct.remaining_indices) > 0:
80
+ remaining_std = np.sqrt(fast_cov_struct.remaining_var)
81
+ z_remaining = np.random.normal(
82
+ loc=0,
83
+ scale=remaining_std,
84
+ size=(n_samples, len(fast_cov_struct.remaining_indices))
26
85
  )
27
- normal_distn = norm(0, np.diag(sigma[group] ** 0.5))
28
- u[ix] = normal_distn.cdf(z)
86
+
87
+ # Convert to uniform via marginal CDFs
88
+ normal_distn_remaining = norm(0, remaining_std)
89
+ u[:, fast_cov_struct.remaining_indices] = normal_distn_remaining.cdf(z_remaining)
90
+
29
91
  return u
30
92
 
31
93
 
@@ -18,7 +18,7 @@ def nullify(params: dict, id: str, mask: Union[np.array, None] = None) -> dict:
18
18
  null_params = nullify(params, "beta", mask)
19
19
  """
20
20
  if mask is None:
21
- mask = np.ones(params[id].shape)
21
+ mask = np.ones(params[id].shape, dtype=bool)
22
22
 
23
23
  result = deepcopy(params)
24
24
  result[id][mask] = 0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: scdesigner
3
- Version: 0.0.3
3
+ Version: 0.0.5
4
4
  Summary: Interactive simulation for rigorous and transparent multi-omics analysis.
5
5
  Project-URL: Homepage, https://github.com/krisrs1128/scDesigner/
6
6
  Project-URL: Issues, https://github.com/krisrs1128/scDesigner/Issues/
@@ -11,7 +11,6 @@ Classifier: Programming Language :: Python :: 3
11
11
  Requires-Python: >=3.8
12
12
  Requires-Dist: anndata
13
13
  Requires-Dist: formulaic
14
- Requires-Dist: lightning
15
14
  Requires-Dist: numpy
16
15
  Requires-Dist: pandas
17
16
  Requires-Dist: rich
@@ -6,14 +6,14 @@ scdesigner/data/sparse.py,sha256=lMp8gI8sq_fUTj3HiAmia0YiCdMRS1QGyyRePmuq3zY,138
6
6
  scdesigner/diagnose/__init__.py,sha256=XRBlc0_ns9Tbwq-aMWMIv1c-tcTOgRFQBywBFV46Gt4,3579
7
7
  scdesigner/diagnose/aic_bic.py,sha256=9GmtxdEXbbmCvL__pm7jSvQeeCb8KjwawY3UebahVBs,5022
8
8
  scdesigner/diagnose/plot.py,sha256=JP1vLbZVnMs171aOBawKbbCUgGiwIfCGKWrQyNh7Y2s,7225
9
- scdesigner/estimators/__init__.py,sha256=XqtTCAE02FyxMJkbnlg-qi-C_eF-OoOzP2wqa3MJEJ8,907
9
+ scdesigner/estimators/__init__.py,sha256=TDkbc25TvXJp5O4_U2QipM8tmEoJ0HpNG53rlYthsBk,1221
10
10
  scdesigner/estimators/bernoulli.py,sha256=vnIETFbHjXErgkKps4ImgeM5W6nta0c-8-ZetYpqK8g,3324
11
11
  scdesigner/estimators/gaussian.py,sha256=-CKumYQWw8KbshdCmCdFgpcH5dNbqY-3vNTXCyjJbqc,4311
12
- scdesigner/estimators/gaussian_copula_factory.py,sha256=tAX4jABUWNmiheDot74k7yuRlviXwZ0iO7meECbotEM,5391
12
+ scdesigner/estimators/gaussian_copula_factory.py,sha256=yeIx_C2fN2jiotTkJstV4pjQYW4aO6RJ-jMuuXR2HVs,14054
13
13
  scdesigner/estimators/glm_factory.py,sha256=tjVlEfJwBPK_Vk4G4P_eTYnBqeI8Y0r5_u_iSYB7qfA,2618
14
- scdesigner/estimators/negbin.py,sha256=JCVDIqFe-IfgYSmFz8gAPXFb4PiIaYiFTqNlU42iFZg,4812
14
+ scdesigner/estimators/negbin.py,sha256=4q_XZOZA8gYHXOOY5Mz17cexs-kzadqB9JlxqWnZ7VQ,5639
15
15
  scdesigner/estimators/pnmf.py,sha256=-0WDbwh0uU4V-eyl_0uQ3qLZpT2IuEo9Pa-GEC6vJYI,4781
16
- scdesigner/estimators/poisson.py,sha256=-B-xXX-XRS22PQUkOTCFiT8HlWtKIXN2WCL_EJtb83w,3365
16
+ scdesigner/estimators/poisson.py,sha256=FMYtkip2NDYmmkyEuO_9hSuM6gtM-AZpPa9GwFrFfD8,4174
17
17
  scdesigner/estimators/zero_inflated_negbin.py,sha256=OeMbXf20wzhJUmll2meMXBdLfeHGSya2hikmXNRUKI4,7959
18
18
  scdesigner/estimators/zero_inflated_poisson.py,sha256=2lrD2H4QMBxhNbak3gDUHVxPBURLS16IyfN5eCEbq9Q,3364
19
19
  scdesigner/format/__init__.py,sha256=PR12wZFvixIqHEd--d1oZkuj6o8tAQx-rgnpUKkr03I,179
@@ -21,20 +21,20 @@ scdesigner/format/format.py,sha256=WLsGnfeM52Mg3fhKHwPx0XbkYJSXfehu2_HmQfUpHdY,5
21
21
  scdesigner/format/print.py,sha256=HK3yLQcFw-f5-nSxMy52bD8Mixw_xxAYV2W9K4_ULwg,794
22
22
  scdesigner/minimal/__init__.py,sha256=IdMK1a2iiYyJ1gsWoupc2_3wyEu3Udjbm_iC0U724kM,378
23
23
  scdesigner/minimal/bernoulli.py,sha256=KvkXiS5aOYIT_L_xymaAWd92ZJyRXZowTcOv_b1RsJo,2304
24
- scdesigner/minimal/composite.py,sha256=H-OzRUzv52F49zw3eM9Bz6HPC4deXQDNr9zTkR_7Xz4,4548
25
- scdesigner/minimal/copula.py,sha256=VgjOuKgNhKgSnGEt8GXCdjUfIhucymHFTYmAtCn4c5U,1105
24
+ scdesigner/minimal/composite.py,sha256=lOafmycSpW5U7Yrmm_NgRJx8Y9GvkuS0BVf8IGFse9c,4536
25
+ scdesigner/minimal/copula.py,sha256=nNnK9Pxrp4-jlYWFY4NBt0hiUXGVHjMQMH5P5EhxLVw,8884
26
26
  scdesigner/minimal/formula.py,sha256=VFTadNxn-2elBvCkD_yuyh9O-vDZFiitSqjJkZYJy3Y,879
27
27
  scdesigner/minimal/gaussian.py,sha256=CKsluKk_XduYMNNUZTtENzH3wf4iTRGRc4L2nzq7qYw,2351
28
28
  scdesigner/minimal/kwargs.py,sha256=a32BLKNBj7ont5fR_uUpppYwel55dpP3fqaPeVbmCKM,979
29
- scdesigner/minimal/loader.py,sha256=gufdcWWKo7MyZQ5e7lzia96ZDKV5Pz8v0WwhZGQEeYM,5947
30
- scdesigner/minimal/marginal.py,sha256=mxGZH6walwNY8HUqYxp47chPrIysiLEMRacsqAvCxWI,5425
31
- scdesigner/minimal/negbin.py,sha256=RvReWG50h3LuBBfrHv34AFylcZ8TTjV5WOugR5tBTCg,2572
29
+ scdesigner/minimal/loader.py,sha256=N6bFyRkJGkpi1NIHU9gwfJTbBP_Z51p6tX2ymxML5Fk,7483
30
+ scdesigner/minimal/marginal.py,sha256=5SB1biCpO8PzMb_--mzBdpTHrnbTEt0BMp7188afxd0,5790
31
+ scdesigner/minimal/negbin.py,sha256=NorV0CQ_wmKwTWuWd-Y1mE3gWYmVPHgmjp3pgI86bEI,2563
32
32
  scdesigner/minimal/positive_nonnegative_matrix_factorization.py,sha256=oMIC1aqdH4Cgn2pU446lh2-H3axEe-FSC_ZNcgSBrdk,7445
33
- scdesigner/minimal/scd3.py,sha256=8D8amY7zqSBgyRkrTvYzaf7r9i6mHNtZCCqeopcFbYY,3047
34
- scdesigner/minimal/scd3_instances.py,sha256=7Jb45VzEkpYegxbzqVlqpb-zMkIXebdiIMEWdZZnRzc,1993
33
+ scdesigner/minimal/scd3.py,sha256=4EKi6Wonc0LRfWjLdvn8HMhbuR-c5kZbmVNLGbJd7mw,3083
34
+ scdesigner/minimal/scd3_instances.py,sha256=vbCbF_SbCqyCUdi_STwV2ufn9ouoEk14_D9g4Wc9O14,1969
35
35
  scdesigner/minimal/simulator.py,sha256=DmeT_uXswR9larJ_OuysXUt32BqIVwxS-2yYMk9PGQw,631
36
- scdesigner/minimal/standard_covariance.py,sha256=9JlAI9C8bWSGSQRFtBYnKMnoF3UghB0DG08A4piOJ4E,4956
37
- scdesigner/minimal/transform.py,sha256=FdTOsPoGk3I-M-t8Ugmqx-Znxo0avJs5Os3kCkHotIk,4919
36
+ scdesigner/minimal/standard_copula.py,sha256=_NHCKL90eR7oHmgYtZAC9dyZDBKqYj987nrUxYoDAtM,16204
37
+ scdesigner/minimal/transform.py,sha256=j2sj2vHpFi3YdK9sz4Npl05i8UPGBcDU3kT-n72u28s,5459
38
38
  scdesigner/minimal/zero_inflated_negbin.py,sha256=npDoyAGWs9n5rl3HwvAUvDG9UDG3DlwRMUGwkpeqvlc,3179
39
39
  scdesigner/predictors/__init__.py,sha256=3ycFB7ifR2y-27Kx2GhESyfcyZsJ-6ds6vxgzwrl6ss,462
40
40
  scdesigner/predictors/bernoulli.py,sha256=ln7GpOTh7nxCVX0nvARPXw0E-l3SalDwur0iUEz-SWQ,261
@@ -46,7 +46,7 @@ scdesigner/predictors/zero_inflated_poisson.py,sha256=Krr4r-9vkh7SVy3d45v226NVUY
46
46
  scdesigner/samplers/__init__.py,sha256=ns3tA_Q7jYIPgnYPfItvHCUpBV9fDimUg0dEHcNB7wE,775
47
47
  scdesigner/samplers/bernoulli.py,sha256=1GLOU69_D2sxcedmxC2dIj1MQbmX5D8i0ccvyxjTcR8,791
48
48
  scdesigner/samplers/gaussian.py,sha256=H8DwD4XzzALFguxB1uiwUsUURF_G1B6d2j4r0up-srU,943
49
- scdesigner/samplers/glm_factory.py,sha256=QUtphZEd-hWE8G0e8o98i3ojQqyvucqanrWiMN2oUew,1287
49
+ scdesigner/samplers/glm_factory.py,sha256=aQvFTgBKAk7kV-SnFJ61aSHQGwemxW9Hl4ZXBmwGJug,3607
50
50
  scdesigner/samplers/negbin.py,sha256=d9oWIDxfXl_kNOPCjYUvtVie3uQmBN6yHfxHbrQiX-I,930
51
51
  scdesigner/samplers/poisson.py,sha256=_TxotqGuVDFRXJOIiBGqRE_qJK0XkXn96clkCDCRV1A,789
52
52
  scdesigner/samplers/zero_inflated_negbin.py,sha256=Z9kUSEQL8Z-lf48wWBLAuz74FcK7mjsKmeIzyBE9ov4,1287
@@ -58,9 +58,9 @@ scdesigner/simulators/pnmf_regression.py,sha256=B_fMK7Q9D4q31HT77384EKZEST7XXVNC
58
58
  scdesigner/transform/__init__.py,sha256=cuLIP0_tocIA3dupO7npH2mLFpL1ApLZlLWPudzPt6M,236
59
59
  scdesigner/transform/amplify.py,sha256=aNxpuyoDpXI4xK5FyCaGKqLqUh4SxFCZhwZeW8XonEQ,351
60
60
  scdesigner/transform/mask.py,sha256=z-NQ6xcnEzCFWvlfCRhQOrE-TWvtAwQ6Cs8KBPahBSk,1032
61
- scdesigner/transform/nullify.py,sha256=OjeS9UJA1Cm8LImuGFytIKnA-Oy_qXY_C9NBFMjyjNQ,780
61
+ scdesigner/transform/nullify.py,sha256=pEtYNDVT2Z_BmVc5CKl3CoxB37KOvqMReoQGnvFYMKE,792
62
62
  scdesigner/transform/split.py,sha256=AK3mU52DHSagdyW-d79tsZ12zMKc5xMF_MockUUvciE,741
63
63
  scdesigner/transform/substitute.py,sha256=pozV7IVJLyUzJVKeaSX86v0bl9foSFIAQpZ0oC18xak,326
64
- scdesigner-0.0.3.dist-info/METADATA,sha256=r2NUcpEnqH9rvuRj3N-qv0yXU7e3yx4fKbojR_RLYvc,766
65
- scdesigner-0.0.3.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
66
- scdesigner-0.0.3.dist-info/RECORD,,
64
+ scdesigner-0.0.5.dist-info/METADATA,sha256=SykpXsuJatehhOj2nFtqVrW5o82d5JEij4r6ndCEf_8,741
65
+ scdesigner-0.0.5.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
66
+ scdesigner-0.0.5.dist-info/RECORD,,