scdesigner 0.0.5__py3-none-any.whl → 0.0.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (84) hide show
  1. scdesigner/base/__init__.py +8 -0
  2. scdesigner/base/copula.py +416 -0
  3. scdesigner/base/marginal.py +391 -0
  4. scdesigner/base/simulator.py +59 -0
  5. scdesigner/copulas/__init__.py +8 -0
  6. scdesigner/copulas/standard_copula.py +645 -0
  7. scdesigner/datasets/__init__.py +5 -0
  8. scdesigner/datasets/pancreas.py +39 -0
  9. scdesigner/distributions/__init__.py +19 -0
  10. scdesigner/{minimal → distributions}/bernoulli.py +42 -14
  11. scdesigner/distributions/gaussian.py +114 -0
  12. scdesigner/distributions/negbin.py +121 -0
  13. scdesigner/distributions/negbin_irls.py +72 -0
  14. scdesigner/distributions/negbin_irls_funs.py +456 -0
  15. scdesigner/distributions/poisson.py +88 -0
  16. scdesigner/{minimal → distributions}/zero_inflated_negbin.py +39 -10
  17. scdesigner/distributions/zero_inflated_poisson.py +103 -0
  18. scdesigner/simulators/__init__.py +24 -28
  19. scdesigner/simulators/composite.py +239 -0
  20. scdesigner/simulators/positive_nonnegative_matrix_factorization.py +477 -0
  21. scdesigner/simulators/scd3.py +486 -0
  22. scdesigner/transform/__init__.py +8 -6
  23. scdesigner/{minimal → transform}/transform.py +1 -1
  24. scdesigner/{minimal → utils}/kwargs.py +4 -1
  25. {scdesigner-0.0.5.dist-info → scdesigner-0.0.10.dist-info}/METADATA +1 -1
  26. scdesigner-0.0.10.dist-info/RECORD +28 -0
  27. {scdesigner-0.0.5.dist-info → scdesigner-0.0.10.dist-info}/WHEEL +1 -1
  28. scdesigner/data/__init__.py +0 -16
  29. scdesigner/data/formula.py +0 -137
  30. scdesigner/data/group.py +0 -123
  31. scdesigner/data/sparse.py +0 -39
  32. scdesigner/diagnose/__init__.py +0 -65
  33. scdesigner/diagnose/aic_bic.py +0 -119
  34. scdesigner/diagnose/plot.py +0 -242
  35. scdesigner/estimators/__init__.py +0 -32
  36. scdesigner/estimators/bernoulli.py +0 -85
  37. scdesigner/estimators/gaussian.py +0 -121
  38. scdesigner/estimators/gaussian_copula_factory.py +0 -367
  39. scdesigner/estimators/glm_factory.py +0 -75
  40. scdesigner/estimators/negbin.py +0 -153
  41. scdesigner/estimators/pnmf.py +0 -160
  42. scdesigner/estimators/poisson.py +0 -124
  43. scdesigner/estimators/zero_inflated_negbin.py +0 -195
  44. scdesigner/estimators/zero_inflated_poisson.py +0 -85
  45. scdesigner/format/__init__.py +0 -4
  46. scdesigner/format/format.py +0 -20
  47. scdesigner/format/print.py +0 -30
  48. scdesigner/minimal/__init__.py +0 -17
  49. scdesigner/minimal/composite.py +0 -119
  50. scdesigner/minimal/copula.py +0 -205
  51. scdesigner/minimal/formula.py +0 -23
  52. scdesigner/minimal/gaussian.py +0 -65
  53. scdesigner/minimal/loader.py +0 -211
  54. scdesigner/minimal/marginal.py +0 -154
  55. scdesigner/minimal/negbin.py +0 -73
  56. scdesigner/minimal/positive_nonnegative_matrix_factorization.py +0 -231
  57. scdesigner/minimal/scd3.py +0 -96
  58. scdesigner/minimal/scd3_instances.py +0 -50
  59. scdesigner/minimal/simulator.py +0 -25
  60. scdesigner/minimal/standard_copula.py +0 -383
  61. scdesigner/predictors/__init__.py +0 -15
  62. scdesigner/predictors/bernoulli.py +0 -9
  63. scdesigner/predictors/gaussian.py +0 -16
  64. scdesigner/predictors/negbin.py +0 -17
  65. scdesigner/predictors/poisson.py +0 -12
  66. scdesigner/predictors/zero_inflated_negbin.py +0 -18
  67. scdesigner/predictors/zero_inflated_poisson.py +0 -18
  68. scdesigner/samplers/__init__.py +0 -23
  69. scdesigner/samplers/bernoulli.py +0 -27
  70. scdesigner/samplers/gaussian.py +0 -25
  71. scdesigner/samplers/glm_factory.py +0 -103
  72. scdesigner/samplers/negbin.py +0 -25
  73. scdesigner/samplers/poisson.py +0 -25
  74. scdesigner/samplers/zero_inflated_negbin.py +0 -40
  75. scdesigner/samplers/zero_inflated_poisson.py +0 -16
  76. scdesigner/simulators/composite_regressor.py +0 -72
  77. scdesigner/simulators/glm_simulator.py +0 -167
  78. scdesigner/simulators/pnmf_regression.py +0 -61
  79. scdesigner/transform/amplify.py +0 -14
  80. scdesigner/transform/mask.py +0 -33
  81. scdesigner/transform/nullify.py +0 -25
  82. scdesigner/transform/split.py +0 -23
  83. scdesigner/transform/substitute.py +0 -14
  84. scdesigner-0.0.5.dist-info/RECORD +0 -66
@@ -0,0 +1,645 @@
1
+ from ..base.copula import Copula
2
+ from ..data.formula import standardize_formula
3
+ from ..utils.kwargs import DEFAULT_ALLOWED_KWARGS, _filter_kwargs
4
+ from anndata import AnnData
5
+ from scipy.stats import norm, multivariate_normal
6
+ from tqdm import tqdm
7
+ from typing import Dict, Union, Callable, Tuple
8
+ import numpy as np
9
+ import torch
10
+ from ..base.copula import CovarianceStructure
11
+ import warnings
12
+
13
+
14
class StandardCopula(Copula):
    """
    Gaussian copula model with optional group-specific covariance structures.

    This implementation estimates a multivariate normal dependence structure
    on latent Gaussian variables. Optionally, different covariance
    matrices can be estimated for categorical groups (for example, cell
    types or experimental conditions).

    Parameters
    ----------
    formula : str or dict, optional
        A formula describing how the copula depends on experimental or
        biological conditions. The formula is standardized to ensure that
        a ``"group"`` term is always present. By default ``"~ 1"``.

    Attributes
    ----------
    loader : torch.utils.data.DataLoader
        A data loader object is used to estimate the covariance one batch at a
        time. This allows estimation of the covariance structure in a streaming
        way, without having to load all data into memory.
    n_outcomes : int
        The number of features modeled by this copula. For example,
        this corresponds to the number of genes being simulated.
    parameters : Dict[str, CovarianceStructure]
        A dictionary of CovarianceStructure objects. Each key corresponds to a
        different category specified in the original formula. The covariance
        structure stores the relationships among genes. It can be a standard
        covariance matrix, but may also use more memory-efficient approximations
        like when using CovarianceStructure with a constraint on
        num_modeled_genes. Groups without any observation during fitting
        have no entry.
    groups : list
        The list of groups in the formula.
    n_groups : int
        The number of groups in the formula.

    Examples
    --------
    >>> import numpy as np
    >>> import scanpy as sc
    >>> from scdesigner.copulas.standard_copula import StandardCopula
    >>>
    >>> # Load a small dataset (cells x genes) and keep only a few genes for speed
    >>> adata = sc.datasets.pbmc3k()[:500, :20].copy()
    >>>
    >>> # Instantiate the copula with a simple group formula and set up data
    >>> copula = StandardCopula("group ~ 1")
    >>> copula.setup_data(adata, {"group": "~ 1"}, batch_size=256)
    >>> copula.groups  # groups inferred from the design matrix
    ['Intercept']
    >>> copula.n_outcomes  # number of modeled genes
    20
    >>> # Define a simple rank-based uniformizer used by fit() and likelihood()
    >>> def rank_uniformizer(y, x_dict):
    ...     y_np = y.cpu().numpy()
    ...     # Convert each gene to ranks and scale to (0, 1)
    ...     ranks = np.argsort(np.argsort(y_np, axis=0), axis=0) + 1
    ...     return ranks / (y_np.shape[0] + 1.0)
    >>>
    >>> # Fit the Gaussian copula covariance model
    >>> copula.fit(rank_uniformizer, top_k=10)
    >>> isinstance(copula.parameters, dict)
    True
    >>> # Draw dependent uniform pseudo-observations for a batch of covariates
    >>> y_batch, x_batch = next(iter(copula.loader))
    >>> u = copula.pseudo_obs(x_batch)
    >>> u.shape[1] == copula.n_outcomes
    True
    >>> # Compute per-cell log-likelihoods for the same batch
    >>> ll = copula.likelihood(rank_uniformizer, (y_batch, x_batch))
    >>> ll.shape[0] == y_batch.shape[0]
    True
    >>> # Inspect the effective number of covariance parameters
    >>> n_params = copula.num_params()
    >>> isinstance(n_params, int) and n_params > 0
    True

    """

    def __init__(self, formula: Union[str, dict] = "~ 1"):
        """
        Initialize a :class:`StandardCopula` instance.

        Parameters
        ----------
        formula : str or dict, optional
            Copula formula specifying categorical covariates (e.g. cell type).
            The formula is processed so that a ``"group"`` predictor is present,
            which is then used to estimate group-specific covariance matrices.
        """
        formula = standardize_formula(formula, allowed_keys=["group"])
        super().__init__(formula)
        self.groups = None

    def setup_data(self, adata: AnnData, marginal_formula: Dict[str, str], **kwargs):
        """
        Set up data and design matrices for covariance estimation.

        After this call, the internal loader produces batches whose
        ``x_dict`` always contains a binary ``"group"`` one-hot matrix.

        Parameters
        ----------
        adata : AnnData
            Annotated data matrix with cells in rows and features (e.g. genes)
            in columns.
        marginal_formula : dict of {str: str}
            Mapping from parameter name to formula used for the marginal
            models. This is combined with the copula formula.
        **kwargs
            Additional keyword arguments; only those listed in
            ``DEFAULT_ALLOWED_KWARGS["data"]`` are forwarded to the parent
            ``setup_data`` (e.g. ``batch_size``).

        Raises
        ------
        ValueError
            If the inferred ``"group"`` design matrix is not binary, i.e.
            contains entries other than 0 or 1.
        """
        data_kwargs = _filter_kwargs(kwargs, DEFAULT_ALLOWED_KWARGS["data"])
        super().setup_data(adata, marginal_formula, **data_kwargs)
        _, obs_batch = next(iter(self.loader))
        obs_batch_group = obs_batch.get("group")

        # fill in group indexing variables
        self.groups = self.loader.dataset.predictor_names["group"]
        self.n_groups = len(self.groups)
        self._group_col = {g: i for i, g in enumerate(self.groups)}

        # check that obs_batch is a binary grouping matrix (only if group exists);
        # note that only the first batch is inspected
        if obs_batch_group is not None:
            unique_vals = torch.unique(obs_batch_group)
            if not torch.all((unique_vals == 0) | (unique_vals == 1)).item():
                raise ValueError(
                    "Only categorical groups are currently supported in copula covariance estimation."
                )

    def fit(self, uniformizer: Callable, **kwargs):
        """
        Fit the Gaussian copula covariance model.

        The data are first transformed to pseudo-Gaussian variables via the
        ``uniformizer`` (PIT) and an inverse normal CDF. Depending on
        ``top_k``, either a full covariance matrix is estimated for all genes,
        or a block structure with an explicit covariance for the top-``k``
        most expressed genes and diagonal variances for the remainder.

        Parameters
        ----------
        uniformizer : callable
            Function with signature ``uniformizer(y, x_dict) -> np.ndarray``
            (or tensor convertible to ``np.ndarray``) that converts
            expression data to uniform ``[0, 1]`` values.
        **kwargs
            Additional keyword arguments controlling the fit.

        Other Parameters
        ----------------
        top_k : int, optional
            Number of most expressed genes to model with a full covariance
            block. If ``None``, a full covariance matrix is estimated for all genes.

        Raises
        ------
        ValueError
            If ``top_k`` is not a positive integer or exceeds the number
            of modeled outcomes.
        """
        # single source of truth for the top_k checks
        top_k = self._validate_parameters(**kwargs)
        if top_k is None:
            self.parameters = self._compute_full_covariance(uniformizer)
            return

        # rank genes by total expression; the k largest get the full block
        gene_total_expression = np.array(self.adata.X.sum(axis=0)).flatten()
        sorted_indices = np.argsort(gene_total_expression)
        top_k_indices = sorted_indices[-top_k:]
        remaining_indices = sorted_indices[:-top_k]
        self.parameters = self._compute_block_covariance(
            uniformizer, top_k_indices, remaining_indices, top_k
        )

    def pseudo_obs(self, x_dict: Dict):
        """
        Sample dependent uniform pseudo-observations from the fitted copula.

        Parameters
        ----------
        x_dict : dict
            Dictionary of covariates for the current batch. Must contain a
            key ``"group"`` with a one-hot matrix representing group
            memberships for each observation.

        Returns
        -------
        np.ndarray
            Array of shape ``(n_cells, n_genes)`` containing uniform
            pseudo-observations sampled from the fitted copula. Rows that
            belong to no fitted group are left at zero.
        """
        group_ix = self._group_indices(x_dict)
        n_cells = len(x_dict.get("group"))

        # initialize the result
        u = np.zeros((n_cells, self.n_outcomes))

        # loop over groups and sample each part in turn
        for group, cov_struct in self.parameters.items():
            ix = group_ix[group]
            if len(ix) == 0:
                # nothing to draw for this group in the current batch
                continue
            if cov_struct.remaining_var is not None:
                u[ix] = self._fast_normal_pseudo_obs(len(ix), cov_struct)
            else:
                u[ix] = self._normal_pseudo_obs(len(ix), cov_struct)
        return u

    def likelihood(
        self,
        uniformizer: Callable,
        batch: Tuple[torch.Tensor, Dict[str, torch.Tensor]],
    ):
        """
        Compute per-cell log-likelihood under the fitted copula model.

        Parameters
        ----------
        uniformizer : callable
            Function that converts expression data to uniform ``[0, 1]``
            pseudo-observations given covariates, with signature
            ``uniformizer(y, x_dict)``.
        batch : tuple of (torch.Tensor, dict)
            A mini-batch as returned by the internal data loader, containing:

            * ``y`` (:class:`torch.Tensor`): expression data of shape
              ``(n_cells, n_genes)``.
            * ``x_dict`` (dict of str to :class:`torch.Tensor`): covariate
              matrices, including a ``"group"`` one-hot matrix.

        Returns
        -------
        np.ndarray
            One-dimensional array of shape (n_cells,) with the
            log-likelihood for each observation. Observations that belong to
            no fitted group keep a value of zero.
        """
        # uniformize the observations and map them to the latent normal scale
        y, x_dict = batch
        u = uniformizer(y, x_dict)
        z = norm.ppf(u)

        # a bare covariance structure is treated as belonging to the first group
        parameters = self.parameters
        if not isinstance(parameters, dict):
            parameters = {self.groups[0]: parameters}

        group_ix = self._group_indices(x_dict)
        ll = np.zeros(len(z))

        for group, cov_struct in parameters.items():
            ix = group_ix[group]
            if len(ix) == 0:
                continue
            z_group = z[ix]

            ll_group = multivariate_normal.logpdf(
                z_group[:, cov_struct.modeled_indices],
                np.zeros(cov_struct.num_modeled_genes),
                cov_struct.cov.values,
            )
            if cov_struct.num_remaining_genes > 0:
                # The remaining genes are modeled as independent normals, so
                # their joint log-density is the sum of the per-gene terms.
                # Without the sum the (n_cells, n_remaining) array cannot be
                # combined with the per-cell modeled term.
                ll_remaining = norm.logpdf(
                    z_group[:, cov_struct.remaining_indices],
                    loc=0,
                    scale=np.sqrt(cov_struct.remaining_var.values),
                ).sum(axis=1)
                ll_group = ll_group + ll_remaining
            ll[ix] = ll_group
        return ll

    def num_params(self, **kwargs):
        """
        Return the effective number of covariance parameters.

        Parameters
        ----------
        **kwargs
            Currently unused, kept for consistency with other copula
            implementations.

        Returns
        -------
        int
            Total number of free covariance parameters across all fitted
            groups, computed as the number of unique off-diagonal entries in
            each modeled covariance block.
        """
        fitted = self.parameters
        total = 0
        for g in self.groups:
            # groups without observations are skipped during fitting and
            # therefore have no entry in the parameter dictionary
            if g not in fitted:
                continue
            k = int(fitted[g].num_modeled_genes)
            total += k * (k - 1) // 2
        return total

    def _validate_parameters(self, **kwargs):
        """
        Internal helper to validate keyword arguments for :meth:`fit`.

        Returns
        -------
        int or None
            The validated ``top_k`` value, or ``None`` when it was not given.

        Raises
        ------
        ValueError
            If ``top_k`` is not a positive integer or exceeds the number
            of modeled outcomes.
        """
        top_k = kwargs.get("top_k", None)
        if top_k is None:
            return None
        if not isinstance(top_k, int):
            raise ValueError("top_k must be an integer")
        if top_k <= 0:
            raise ValueError("top_k must be positive")
        if top_k > self.n_outcomes:
            raise ValueError(
                f"top_k ({top_k}) cannot exceed number of outcomes "
                f"({self.n_outcomes})"
            )
        return top_k

    def _group_indices(self, x_dict: Dict) -> Dict[Union[str, int], np.ndarray]:
        """
        Convert the one-hot ``"group"`` matrix of a batch to row indices.

        Parameters
        ----------
        x_dict : dict
            Batch covariates containing a ``"group"`` one-hot tensor.

        Returns
        -------
        dict
            Mapping ``{group: indices of the rows that belong to it}``.
        """
        memberships = x_dict.get("group").cpu().numpy()
        return {
            g: np.where(memberships[:, self._group_col[g]] == 1)[0]
            for g in self.groups
        }

    def _accumulate_top_k_stats(
        self, uniformizer: Callable, top_k_idx, rem_idx, top_k
    ) -> Tuple[
        Dict[Union[str, int], np.ndarray],
        Dict[Union[str, int], np.ndarray],
        Dict[Union[str, int], np.ndarray],
        Dict[Union[str, int], np.ndarray],
        Dict[Union[str, int], int],
    ]:
        """
        Accumulate sufficient statistics for top-``k`` block covariance.

        Parameters
        ----------
        uniformizer : callable
            Function that converts each batch of counts to uniform values.
        top_k_idx : np.ndarray
            Array of indices corresponding to the top-``k`` genes.
        rem_idx : np.ndarray
            Array of indices for the remaining genes.
        top_k : int
            Number of top genes modeled with a full covariance block.

        Returns
        -------
        top_k_sums : dict
            Per-group sums of the transformed top-``k`` genes.
        top_k_second_moments : dict
            Per-group second-moment matrices for the top-``k`` genes.
        rem_sums : dict
            Per-group sums for the remaining genes.
        rem_second_moments : dict
            Per-group sums of squared values for the remaining genes.
        Ng : dict
            Per-group number of observations contributing to the statistics.
        """
        n_remaining = self.n_outcomes - top_k
        top_k_sums = {g: np.zeros(top_k) for g in self.groups}
        top_k_second_moments = {g: np.zeros((top_k, top_k)) for g in self.groups}
        rem_sums = {g: np.zeros(n_remaining) for g in self.groups}
        rem_second_moments = {g: np.zeros(n_remaining) for g in self.groups}
        Ng = {g: 0 for g in self.groups}

        for y, x_dict in tqdm(self.loader, desc="Estimating top-k copula covariance"):
            memberships = x_dict.get("group").cpu().numpy()
            z = norm.ppf(uniformizer(y, x_dict))

            for g in self.groups:
                mask = memberships[:, self._group_col[g]] == 1
                if not np.any(mask):
                    continue

                z_g = z[mask]
                top_k_z, rem_z = z_g[:, top_k_idx], z_g[:, rem_idx]

                top_k_sums[g] += top_k_z.sum(axis=0)
                top_k_second_moments[g] += top_k_z.T @ top_k_z

                # only the diagonal is tracked for the remaining genes
                rem_sums[g] += rem_z.sum(axis=0)
                rem_second_moments[g] += (rem_z**2).sum(axis=0)

                Ng[g] += mask.sum()

        return top_k_sums, top_k_second_moments, rem_sums, rem_second_moments, Ng

    def _accumulate_full_stats(
        self, uniformizer: Callable
    ) -> Tuple[
        Dict[Union[str, int], np.ndarray],
        Dict[Union[str, int], np.ndarray],
        Dict[Union[str, int], int],
    ]:
        """
        Accumulate sufficient statistics for full covariance estimation.

        Parameters
        ----------
        uniformizer : callable
            Function that converts each batch of expression counts to
            uniform values.

        Returns
        -------
        sums : dict
            Per-group sums of transformed values for all genes.
        second_moments : dict
            Per-group second-moment matrices for all genes.
        Ng : dict
            Per-group number of observations contributing to the statistics.
        """
        sums = {g: np.zeros(self.n_outcomes) for g in self.groups}
        second_moments = {
            g: np.zeros((self.n_outcomes, self.n_outcomes)) for g in self.groups
        }
        Ng = {g: 0 for g in self.groups}

        for y, x_dict in tqdm(self.loader, desc="Estimating copula covariance"):
            memberships = x_dict.get("group").cpu().numpy()
            z = norm.ppf(uniformizer(y, x_dict))

            for g in self.groups:
                mask = memberships[:, self._group_col[g]] == 1
                if not np.any(mask):
                    continue

                z_g = z[mask]
                second_moments[g] += z_g.T @ z_g
                sums[g] += z_g.sum(axis=0)
                Ng[g] += mask.sum()

        return sums, second_moments, Ng

    def _compute_block_covariance(
        self,
        uniformizer: Callable,
        top_k_idx: np.ndarray,
        rem_idx: np.ndarray,
        top_k: int,
    ) -> Dict[Union[str, int], CovarianceStructure]:
        """
        Compute block covariance structures for top-``k`` and remaining genes.

        Parameters
        ----------
        uniformizer : callable
            Function that converts each batch of expression counts to
            uniform values.
        top_k_idx : np.ndarray
            Indices of the top-``k`` genes in the original feature ordering.
        rem_idx : np.ndarray
            Indices of the remaining genes in the original feature ordering.
        top_k : int
            Number of top genes modeled with a full covariance block.

        Returns
        -------
        dict
            Mapping from group labels to :class:`CovarianceStructure`
            objects that encode the estimated covariance for each group.
            Groups without observations are omitted (a warning is issued).
        """
        (
            top_k_sums,
            top_k_second_moments,
            remaining_sums,
            remaining_second_moments,
            Ng,
        ) = self._accumulate_top_k_stats(uniformizer, top_k_idx, rem_idx, top_k)

        # gene names do not depend on the group, so look them up once
        top_k_names = self.adata.var_names[top_k_idx]
        remaining_names = self.adata.var_names[rem_idx]

        covariance = {}
        for g in self.groups:
            if Ng[g] == 0:
                warnings.warn(f"Group {g} has no observations, skipping")
                continue
            # moment estimators: E[zz^T] - E[z]E[z]^T and E[z^2] - E[z]^2
            mean_top_k = top_k_sums[g] / Ng[g]
            cov_top_k = top_k_second_moments[g] / Ng[g] - np.outer(
                mean_top_k, mean_top_k
            )
            mean_remaining = remaining_sums[g] / Ng[g]
            var_remaining = remaining_second_moments[g] / Ng[g] - mean_remaining**2
            covariance[g] = CovarianceStructure(
                cov=cov_top_k,
                modeled_names=top_k_names,
                modeled_indices=top_k_idx,
                remaining_var=var_remaining,
                remaining_indices=rem_idx,
                remaining_names=remaining_names,
            )
        return covariance

    def _compute_full_covariance(
        self, uniformizer: Callable
    ) -> Dict[Union[str, int], CovarianceStructure]:
        """
        Compute full covariance matrices for all genes.

        Parameters
        ----------
        uniformizer : callable
            Function that converts each batch of expression counts to
            uniform values.

        Returns
        -------
        dict
            Mapping from group labels to :class:`CovarianceStructure`
            objects, each containing a full covariance matrix for all genes.
            Groups without observations are omitted (a warning is issued).
        """
        sums, second_moments, Ng = self._accumulate_full_stats(uniformizer)
        covariance = {}
        for g in self.groups:
            if Ng[g] == 0:
                warnings.warn(f"Group {g} has no observations, skipping")
                continue
            mean = sums[g] / Ng[g]
            cov = second_moments[g] / Ng[g] - np.outer(mean, mean)
            covariance[g] = CovarianceStructure(
                cov=cov,
                modeled_names=self.adata.var_names,
                modeled_indices=np.arange(self.n_outcomes),
                remaining_var=None,
                remaining_indices=None,
                remaining_names=None,
            )
        return covariance

    def _fast_normal_pseudo_obs(
        self, n_samples: int, cov_struct: CovarianceStructure
    ) -> np.ndarray:
        """
        Sample uniform pseudo-observations using a block covariance structure.

        Parameters
        ----------
        n_samples : int
            Number of samples (cells) to generate.
        cov_struct : CovarianceStructure
            Covariance structure with a modeled block and diagonal variances
            for remaining genes.

        Returns
        -------
        np.ndarray
            Array of shape ``(n_samples, total_genes)`` containing uniform
            pseudo-observations.
        """
        u = np.zeros((n_samples, cov_struct.total_genes))
        cov_modeled = cov_struct.cov.values
        sd_remaining = cov_struct.remaining_var.values**0.5

        z_modeled = np.random.multivariate_normal(
            mean=np.zeros(cov_struct.num_modeled_genes),
            cov=cov_modeled,
            size=n_samples,
        )
        z_remaining = np.random.normal(
            loc=0,
            scale=sd_remaining,
            size=(n_samples, cov_struct.num_remaining_genes),
        )

        # push each gene through its own marginal normal CDF to obtain uniforms
        u[:, cov_struct.modeled_indices] = norm(0, np.diag(cov_modeled) ** 0.5).cdf(
            z_modeled
        )
        u[:, cov_struct.remaining_indices] = norm(0, sd_remaining).cdf(z_remaining)

        return u

    def _normal_pseudo_obs(
        self, n_samples: int, cov_struct: CovarianceStructure
    ) -> np.ndarray:
        """
        Sample uniform pseudo-observations from a full covariance matrix.

        Parameters
        ----------
        n_samples : int
            Number of samples (cells) to generate.
        cov_struct : CovarianceStructure
            Covariance structure containing a full covariance matrix
            for all genes.

        Returns
        -------
        np.ndarray
            Array of shape ``(n_samples, total_genes)`` containing uniform
            pseudo-observations.
        """
        cov = cov_struct.cov.values
        z = np.random.multivariate_normal(
            mean=np.zeros(cov_struct.total_genes),
            cov=cov,
            size=n_samples,
        )
        # marginal standard deviations come from the covariance diagonal
        return norm(0, np.diag(cov) ** 0.5).cdf(z)
@@ -0,0 +1,5 @@
1
# Package initializer for the bundled datasets.
# Importing the submodule binds it to the package attribute ``pancreas``;
# the assignment below then rebinds that same name, so statement order matters.
from .pancreas import fetch_pancreas

# NOTE(review): this call runs at import time. fetch_pancreas() is called with
# its defaults, so importing this package may trigger a download and disk
# writes when no cached copy exists — confirm that this is intended.
pancreas = fetch_pancreas()

# Only the loaded dataset object is advertised as public API.
__all__ = ["pancreas"]
@@ -0,0 +1,39 @@
1
+ from pathlib import Path
2
+ from typing import Optional, Union
3
+ import anndata
4
+ import joblib
5
+ import os
6
+ import urllib.request
7
+
8
+ ARCHIVE_URL = "https://figshare.com/ndownloader/files/60087086"
9
+
10
+
11
def _ensure_data_home(data_home: Optional[Union[str, os.PathLike]]) -> Path:
    """Return the dataset cache directory, creating it when necessary.

    Parameters
    ----------
    data_home : str or os.PathLike or None
        Directory to use. When ``None``, ``~/.scdesigner_data`` (relative to
        the current user's home directory) is used.

    Returns
    -------
    pathlib.Path
        The directory, which exists once this function returns.
    """
    if data_home is None:
        root = Path.home() / ".scdesigner_data"
    else:
        root = Path(data_home)
    # create missing parents as well; an already existing directory is fine
    root.mkdir(parents=True, exist_ok=True)
    return root
15
+
16
+
17
def fetch_pancreas(
    *,
    data_home: Optional[Union[str, os.PathLike]] = None,
    download_if_missing: bool = True,
) -> Optional[object]:
    """Load the pancreas dataset, downloading and caching it when needed.

    Parameters
    ----------
    data_home : str or os.PathLike, optional
        Directory holding the cache. Passed to ``_ensure_data_home``, which
        falls back to ``~/.scdesigner_data`` when ``None``.
    download_if_missing : bool, optional
        When ``False`` and no cached copy exists, nothing is downloaded and
        ``None`` is returned.

    Returns
    -------
    object or None
        The object read by ``anndata.read_h5ad`` (or restored from the joblib
        cache). ``None`` if the data are unavailable: either downloading was
        disabled, or the download/parsing/caching failed (a
        ``RuntimeWarning`` is issued in the latter case).
    """
    import warnings

    data_home_path = _ensure_data_home(data_home)
    cache_path = data_home_path / "pancreas.joblib"
    if cache_path.exists():
        # NOTE(security): joblib.load unpickles the file and can execute
        # arbitrary code; only point data_home at a trusted location.
        return joblib.load(cache_path)

    if not download_if_missing:
        return None

    tmp_path = data_home_path / "pancreas.h5ad"
    try:
        urllib.request.urlretrieve(ARCHIVE_URL, str(tmp_path))
        adata = anndata.read_h5ad(str(tmp_path))
        joblib.dump(adata, str(cache_path), compress=6)
    except Exception as err:
        # Best effort: stay non-fatal as before, but tell the user why no
        # data came back instead of failing silently.
        if cache_path.exists():
            # a partially written cache would break every later load
            cache_path.unlink()
        warnings.warn(
            f"Could not fetch the pancreas dataset: {err!r}",
            RuntimeWarning,
            stacklevel=2,
        )
        return None
    finally:
        # the raw download is only an intermediate file; remove it on both
        # success and failure
        if tmp_path.exists():
            tmp_path.unlink()
    return adata