scdesigner 0.0.5__py3-none-any.whl → 0.0.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (84)
  1. scdesigner/base/__init__.py +8 -0
  2. scdesigner/base/copula.py +416 -0
  3. scdesigner/base/marginal.py +391 -0
  4. scdesigner/base/simulator.py +59 -0
  5. scdesigner/copulas/__init__.py +8 -0
  6. scdesigner/copulas/standard_copula.py +645 -0
  7. scdesigner/datasets/__init__.py +5 -0
  8. scdesigner/datasets/pancreas.py +39 -0
  9. scdesigner/distributions/__init__.py +19 -0
  10. scdesigner/{minimal → distributions}/bernoulli.py +42 -14
  11. scdesigner/distributions/gaussian.py +114 -0
  12. scdesigner/distributions/negbin.py +121 -0
  13. scdesigner/distributions/negbin_irls.py +72 -0
  14. scdesigner/distributions/negbin_irls_funs.py +456 -0
  15. scdesigner/distributions/poisson.py +88 -0
  16. scdesigner/{minimal → distributions}/zero_inflated_negbin.py +39 -10
  17. scdesigner/distributions/zero_inflated_poisson.py +103 -0
  18. scdesigner/simulators/__init__.py +24 -28
  19. scdesigner/simulators/composite.py +239 -0
  20. scdesigner/simulators/positive_nonnegative_matrix_factorization.py +477 -0
  21. scdesigner/simulators/scd3.py +486 -0
  22. scdesigner/transform/__init__.py +8 -6
  23. scdesigner/{minimal → transform}/transform.py +1 -1
  24. scdesigner/{minimal → utils}/kwargs.py +4 -1
  25. {scdesigner-0.0.5.dist-info → scdesigner-0.0.10.dist-info}/METADATA +1 -1
  26. scdesigner-0.0.10.dist-info/RECORD +28 -0
  27. {scdesigner-0.0.5.dist-info → scdesigner-0.0.10.dist-info}/WHEEL +1 -1
  28. scdesigner/data/__init__.py +0 -16
  29. scdesigner/data/formula.py +0 -137
  30. scdesigner/data/group.py +0 -123
  31. scdesigner/data/sparse.py +0 -39
  32. scdesigner/diagnose/__init__.py +0 -65
  33. scdesigner/diagnose/aic_bic.py +0 -119
  34. scdesigner/diagnose/plot.py +0 -242
  35. scdesigner/estimators/__init__.py +0 -32
  36. scdesigner/estimators/bernoulli.py +0 -85
  37. scdesigner/estimators/gaussian.py +0 -121
  38. scdesigner/estimators/gaussian_copula_factory.py +0 -367
  39. scdesigner/estimators/glm_factory.py +0 -75
  40. scdesigner/estimators/negbin.py +0 -153
  41. scdesigner/estimators/pnmf.py +0 -160
  42. scdesigner/estimators/poisson.py +0 -124
  43. scdesigner/estimators/zero_inflated_negbin.py +0 -195
  44. scdesigner/estimators/zero_inflated_poisson.py +0 -85
  45. scdesigner/format/__init__.py +0 -4
  46. scdesigner/format/format.py +0 -20
  47. scdesigner/format/print.py +0 -30
  48. scdesigner/minimal/__init__.py +0 -17
  49. scdesigner/minimal/composite.py +0 -119
  50. scdesigner/minimal/copula.py +0 -205
  51. scdesigner/minimal/formula.py +0 -23
  52. scdesigner/minimal/gaussian.py +0 -65
  53. scdesigner/minimal/loader.py +0 -211
  54. scdesigner/minimal/marginal.py +0 -154
  55. scdesigner/minimal/negbin.py +0 -73
  56. scdesigner/minimal/positive_nonnegative_matrix_factorization.py +0 -231
  57. scdesigner/minimal/scd3.py +0 -96
  58. scdesigner/minimal/scd3_instances.py +0 -50
  59. scdesigner/minimal/simulator.py +0 -25
  60. scdesigner/minimal/standard_copula.py +0 -383
  61. scdesigner/predictors/__init__.py +0 -15
  62. scdesigner/predictors/bernoulli.py +0 -9
  63. scdesigner/predictors/gaussian.py +0 -16
  64. scdesigner/predictors/negbin.py +0 -17
  65. scdesigner/predictors/poisson.py +0 -12
  66. scdesigner/predictors/zero_inflated_negbin.py +0 -18
  67. scdesigner/predictors/zero_inflated_poisson.py +0 -18
  68. scdesigner/samplers/__init__.py +0 -23
  69. scdesigner/samplers/bernoulli.py +0 -27
  70. scdesigner/samplers/gaussian.py +0 -25
  71. scdesigner/samplers/glm_factory.py +0 -103
  72. scdesigner/samplers/negbin.py +0 -25
  73. scdesigner/samplers/poisson.py +0 -25
  74. scdesigner/samplers/zero_inflated_negbin.py +0 -40
  75. scdesigner/samplers/zero_inflated_poisson.py +0 -16
  76. scdesigner/simulators/composite_regressor.py +0 -72
  77. scdesigner/simulators/glm_simulator.py +0 -167
  78. scdesigner/simulators/pnmf_regression.py +0 -61
  79. scdesigner/transform/amplify.py +0 -14
  80. scdesigner/transform/mask.py +0 -33
  81. scdesigner/transform/nullify.py +0 -25
  82. scdesigner/transform/split.py +0 -23
  83. scdesigner/transform/substitute.py +0 -14
  84. scdesigner-0.0.5.dist-info/RECORD +0 -66
@@ -0,0 +1,103 @@
1
+ from ..data.formula import standardize_formula
2
+ from ..base.marginal import GLMPredictor, Marginal
3
+ from ..data.loader import _to_numpy
4
+ from typing import Union, Dict, Optional, Tuple
5
+ import torch
6
+ import numpy as np
7
+ from scipy.stats import poisson, bernoulli
8
+
9
class ZeroInflatedPoisson(Marginal):
    """Zero-Inflated Poisson marginal estimator

    This subclass models counts with an explicit zero-inflation component.
    For each feature j the observation follows a mixture: with probability
    `pi_j(x)` the value is an extra zero, otherwise the count is drawn from
    a Poisson distribution with mean `mu_j(x)`. Both `mu_j(x)` and the
    inflation probability `pi_j(x)` may depend on covariates `x` through the
    `formula` argument.

    The allowed formula keys are 'mean' and 'zero_inflation'. If a string
    formula is supplied it is taken to specify the `mean` by default.

    Examples
    --------
    >>> from scdesigner.distributions import ZeroInflatedPoisson
    >>> from scdesigner.datasets import pancreas
    >>>
    >>> sim = ZeroInflatedPoisson(formula={"mean": "~ pseudotime", "zero_inflation": "~ pseudotime"})
    >>> sim.setup_data(pancreas)
    >>> sim.fit(max_epochs=1, verbose=False)
    >>>
    >>> # evaluate p(y | x) and model parameters
    >>> y, x = next(iter(sim.loader))
    >>> l = sim.likelihood((y, x))
    >>> y_hat = sim.predict(x)
    >>>
    >>> # convert to quantiles and back
    >>> u = sim.uniformize(y, x)
    >>> y_star = sim.invert(u, x)
    """

    def __init__(self, formula: Union[Dict, str]):
        """Standardize `formula` to the keys 'mean' / 'zero_inflation' and
        hand it to the `Marginal` base class."""
        formula = standardize_formula(formula, allowed_keys=['mean', 'zero_inflation'])
        super().__init__(formula)

    def setup_optimizer(
        self,
        optimizer_class: Optional[callable] = torch.optim.Adam,
        **optimizer_kwargs,
    ):
        """Create the GLM predictor used for fitting and prediction.

        The mean uses an exponential link and the zero-inflation probability
        a sigmoid link; the loss is the negative summed log-likelihood.

        Parameters
        ----------
        optimizer_class : callable, optional
            Optimizer constructor forwarded to `GLMPredictor`.
        **optimizer_kwargs
            Keyword arguments forwarded to `GLMPredictor` as
            `optimizer_kwargs`.

        Raises
        ------
        RuntimeError
            If `self.loader` is None (i.e. `setup_data` was not called).
        """
        if self.loader is None:
            raise RuntimeError("self.loader is not set (call setup_data first)")

        link_funs = {
            "mean": torch.exp,
            "zero_inflation": torch.sigmoid,
        }

        def nll(batch):
            return -self.likelihood(batch).sum()

        self.predict = GLMPredictor(
            n_outcomes=self.n_outcomes,
            feature_dims=self.feature_dims,
            link_fns=link_funs,
            loss_fn=nll,
            optimizer_class=optimizer_class,
            optimizer_kwargs=optimizer_kwargs
        )

    def likelihood(self, batch) -> torch.Tensor:
        """Compute the elementwise log-likelihood of a `(y, x)` batch."""
        y, x = batch
        params = self.predict(x)
        mu = params.get("mean")
        pi = params.get("zero_inflation")

        # 1e-10 offsets keep both logarithms finite
        poisson_loglikelihood = y * torch.log(mu + 1e-10) - mu - torch.lgamma(y + 1)
        return torch.log(
            pi * (y == 0) + (1 - pi) * torch.exp(poisson_loglikelihood) + 1e-10
        )

    def invert(self, u: torch.Tensor, x: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Invert pseudoobservations.

        Applies the quantile function matching the CDF used in
        :meth:`uniformize`, ``F(y) = pi + (1 - pi) * F_pois(y)``: quantiles
        with ``u <= pi`` map to zero, and the remaining ones are rescaled to
        ``(u - pi) / (1 - pi)`` before the Poisson quantile is taken.
        """
        # NOTE(review): assumes `_to_numpy` yields numpy arrays — confirm.
        mu, pi, u = self._local_params(x, u)
        in_zero_mass = u <= pi
        # The 0.5 placeholder keeps masked entries away from the q == 0 edge
        # of the Poisson ppf; those entries are overwritten with 0 below.
        q = np.where(in_zero_mass, 0.5, (u - pi) / np.maximum(1 - pi, 1e-10))
        y = np.where(in_zero_mass, 0.0, poisson(mu).ppf(np.clip(q, 0.0, 1.0)))
        return torch.from_numpy(y).float()

    def uniformize(self, y: torch.Tensor, x: Dict[str, torch.Tensor], epsilon=1e-6) -> torch.Tensor:
        """Return uniformized pseudo-observations for counts y given covariates x.

        Each value is drawn uniformly (via `np.random.uniform`) between the
        zero-inflated CDF at ``y - 1`` (0 when ``y == 0``) and at ``y``, then
        clipped to ``[epsilon, 1 - epsilon]``.
        """
        # cdf values using scipy's parameterization
        mu, pi, y = self._local_params(x, y)
        pois = poisson(mu)
        upper = pi + (1 - pi) * pois.cdf(y)
        lower = np.where(y > 0, pi + (1 - pi) * pois.cdf(y - 1), 0)
        v = np.random.uniform(size=y.shape)
        u = np.clip(v * upper + (1 - v) * lower, epsilon, 1 - epsilon)
        return torch.from_numpy(u).float()

    def _local_params(self, x, y=None) -> Tuple:
        """Predict 'mean' and 'zero_inflation' for `x` and pass them (plus `y`
        when given) through `_to_numpy`."""
        params = self.predict(x)
        mu = params.get('mean')
        pi = params.get('zero_inflation')
        if y is None:
            return _to_numpy(mu, pi)
        return _to_numpy(mu, pi, y)
@@ -1,31 +1,27 @@
1
- from .composite_regressor import CompositeGLMSimulator
2
- from .glm_simulator import (
3
- BernoulliCopulaSimulator,
4
- BernoulliRegressionSimulator,
5
- NegBinCopulaSimulator,
6
- NegBinRegressionSimulator,
7
- PoissonCopulaSimulator,
8
- PoissonRegressionSimulator,
9
- GaussianRegressionSimulator,
10
- GaussianCopulaSimulator,
11
- ZeroInflatedNegBinCopulaSimulator,
12
- ZeroInflatedNegBinRegressionSimulator,
13
- ZeroInflatedPoissonRegressionSimulator,
1
+ """Simulator classes"""
2
+
3
from .scd3 import (
    BernoulliCopula,
    GaussianCopula,
    NegBinCopula,
    NegBinIRLSCopula,
    PoissonCopula,
    ZeroInflatedNegBinCopula,
    ZeroInflatedPoissonCopula,
)
from .composite import CompositeCopula
from .positive_nonnegative_matrix_factorization import PositiveNMF

# Every name listed here must be bound by an import above; an unbound entry
# makes `from scdesigner.simulators import *` fail with AttributeError.
__all__ = [
    "BernoulliCopula",
    "CompositeCopula",
    "GaussianCopula",
    "NegBinCopula",
    "NegBinIRLSCopula",
    "PoissonCopula",
    "PositiveNMF",
    "ZeroInflatedNegBinCopula",
    "ZeroInflatedPoissonCopula",
]
@@ -0,0 +1,239 @@
1
+ """Composite simulator that combines multiple marginals with a Gaussian copula.
2
+
3
+ This module provides :class:`CompositeCopula`, a simulator that fits several
4
+ marginal models and then couples their dependence structure with a
5
+ :class:`~scdesigner.copulas.standard_copula.StandardCopula`.
6
+ """
7
+
8
+ from ..data.loader import obs_loader
9
+ from .scd3 import SCD3Simulator
10
+ from ..copulas.standard_copula import StandardCopula
11
+ from anndata import AnnData
12
+ from typing import Dict, Optional, List
13
+ import numpy as np
14
+ import torch
15
+
16
class CompositeCopula(SCD3Simulator):
    """
    Composite simulator: multiple marginals + a shared Gaussian copula.

    The composite simulator fits each marginal model independently on a
    (potentially different) subset of variables, and then fits a Gaussian
    copula on the *merged* uniformized outputs from all marginals to capture
    cross-feature dependence.

    Each marginal is provided as a pair ``(sel, marginal)`` where:

    - ``sel`` selects which variables in ``adata`` the marginal is responsible
      for (e.g. a list of gene names, a single gene name).
    - ``marginal`` is an object implementing the marginal simulator interface

    Parameters
    ----------
    marginals : list
        List of ``(sel, marginal)`` pairs.
    copula_formula : str, optional
        Formula passed to :class:`~scdesigner.copulas.standard_copula.StandardCopula`
        to determine copula grouping structure (e.g. ``"group ~ 1"``). If
        ``None``, uses the copula's default.

    Attributes
    ----------
    marginals : list
        The provided marginal specifications.
    copula : StandardCopula
        The fitted copula component.
    template : AnnData or None
        Training dataset (set during :meth:`fit`).
    parameters : dict or None
        Fitted parameters, with keys ``"marginal"`` and ``"copula"``.
    merged_formula : dict or None
        Merged (prefixed) formula dictionary used to construct the copula data loader.

    Examples
    --------
    Fit two marginal models on disjoint gene sets and then fit a copula:

    >>> import numpy as np
    >>> import pandas as pd
    >>> from anndata import AnnData
    >>> from scdesigner.simulators import CompositeCopula
    >>> from scdesigner.distributions import NegBin, Poisson
    >>>
    >>> X = np.random.poisson(1.0, size=(100, 10)).astype(float)
    >>> obs = pd.DataFrame({"cell_type": np.random.choice(["A", "B"], size=100)})
    >>> adata = AnnData(X=X, obs=obs)
    >>> adata.var_names = [f"g{i}" for i in range(adata.n_vars)]
    >>>
    >>> # Example selectors: first 5 genes vs last 5 genes
    >>> sel1 = adata.var_names[:5].tolist()
    >>> sel2 = adata.var_names[5:].tolist()
    >>> m1 = NegBin(formula={"mean": "~ cell_type", "dispersion": "~ 1"})
    >>> m2 = Poisson(formula={"mean": "~ cell_type"})
    >>>
    >>> composite = CompositeCopula([(sel1, m1), (sel2, m2)])
    >>> composite.fit(adata, batch_size=256, verbose=False)
    >>> params = composite.predict(adata.obs.iloc[:3], batch_size=3)
    """

    def __init__(self, marginals: List,
                 copula_formula: Optional[str] = None) -> None:
        """Create a composite simulator.

        Parameters
        ----------
        marginals : list
            List of ``(sel, marginal)`` pairs. See class docstring for details.
        copula_formula : str, optional
            Copula grouping formula passed to :class:`StandardCopula`.
        """
        self.marginals = marginals
        self.copula = StandardCopula(copula_formula) if copula_formula is not None else StandardCopula()
        self.template = None
        self.parameters = None
        self.merged_formula = None

    @staticmethod
    def _strip_prefix(x: Dict, prefix: str) -> Dict:
        """Return a copy of ``x`` with ``prefix`` removed from matching keys.

        Keys that do not start with ``prefix`` are kept under their original
        name, so entries belonging to other marginals pass through unchanged.
        """
        return {k.removeprefix(prefix): v for k, v in x.items()}

    def fit(
        self,
        adata: AnnData,
        verbose: bool = True,
        **kwargs,
    ):
        """Fit all marginals and then fit the copula on merged uniforms.

        Parameters
        ----------
        adata : AnnData
            Training dataset.
        verbose : bool, optional
            Whether to print verbose output.
        **kwargs
            Additional keyword arguments forwarded to marginal setup/fit methods
            and to the copula's ``setup_data`` / ``fit`` calls (e.g.
            ``batch_size``).
        """
        self.template = adata
        merged_formula = {}

        # fit each marginal model
        for m, spec in enumerate(self.marginals):
            sel, marginal = spec[0], spec[1]
            marginal.setup_data(adata[:, sel], **kwargs)
            marginal.setup_optimizer(**kwargs)
            marginal.fit(**kwargs, verbose=verbose)

            # prefix this marginal's formula keys so they stay distinct in
            # the shared copula loader
            merged_formula |= {f"group{m}_{k}": v for k, v in marginal.formula.items()}

        # copula simulator
        self.merged_formula = merged_formula
        self.copula.setup_data(adata, merged_formula, **kwargs)
        self.copula.fit(self.merged_uniformize, **kwargs)
        self.parameters = {
            "marginal": [spec[1].parameters for spec in self.marginals],
            "copula": self.copula.parameters
        }

    def merged_uniformize(self, y: torch.Tensor, x: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Produce a merged uniformized matrix for all marginals.

        Delegates to each marginal's `uniformize` method and places the
        result into the columns of a full matrix according to the variable
        selection given in `self.marginals[m][0]`.
        """
        y_np = y.detach().cpu().numpy()
        # NOTE(review): columns not selected by any marginal are left
        # uninitialized by `np.empty_like` — confirm selections cover all
        # variables.
        u = np.empty_like(y_np, dtype=float)

        for m, spec in enumerate(self.marginals):
            sel, marginal = spec[0], spec[1]
            ix = _var_indices(sel, self.template)

            # remove the `group{m}_` prefix we used to distinguish the marginals
            cur_x = self._strip_prefix(x, f"group{m}_")

            # slice the subset of y for this marginal and call its uniformize
            y_sub = torch.from_numpy(y_np[:, ix])
            u[:, ix] = marginal.uniformize(y_sub, cur_x)
        return torch.from_numpy(u)

    def predict(self, obs=None, batch_size: int = 1000, **kwargs):
        """Predict marginal parameters for observations (batched).

        This method constructs an internal loader for ``obs`` using the merged
        (prefixed) formula dictionary, then dispatches per-marginal ``predict``
        calls on each batch after stripping the prefixes.

        Parameters
        ----------
        obs : pandas.DataFrame, optional
            Observation metadata. Defaults to ``self.template.obs``.
        batch_size : int, optional
            Batch size for the internal observation loader.
        **kwargs
            Forwarded to :func:`~scdesigner.data.loader.obs_loader`.

        Returns
        -------
        list[dict[str, np.ndarray]]
            List with one element per marginal. Each element is a dict mapping
            parameter names to numpy arrays, concatenated across batches. A
            marginal's dict is empty when the loader yields no batches.

        Raises
        ------
        RuntimeError
            If called before :meth:`fit` (no merged formula is available).
        """
        if self.merged_formula is None:
            raise RuntimeError("self.merged_formula is not set (call fit first)")

        # prepare an internal data loader for this obs
        if obs is None:
            obs = self.template.obs
        loader = obs_loader(
            obs,
            self.merged_formula,
            batch_size=batch_size,
            **kwargs
        )

        # prepare per-marginal collectors
        local_pred = [[] for _ in self.marginals]

        # for each batch, call each marginal's predict on its view of x
        for _, x_dict in loader:
            for m, spec in enumerate(self.marginals):
                cur_x = self._strip_prefix(x_dict, f"group{m}_")
                local_pred[m].append(spec[1].predict(cur_x))

        # merge batch-wise parameter dicts for each marginal and return
        results = []
        for parts in local_pred:
            if not parts:
                # no batches were produced (e.g. empty obs)
                results.append({})
                continue
            results.append({
                k: torch.cat([d[k] for d in parts]).detach().cpu().numpy()
                for k in parts[0]
            })

        return results
212
+
213
+
214
def _var_indices(sel, adata: AnnData) -> np.ndarray:
    """Return integer indices of ``sel`` within ``adata.var_names``.

    Parameters
    ----------
    sel : str or list of str
        The variable names to select. A single string is treated as a
        one-element list.
    adata : AnnData
        The AnnData object to select variables from.

    Returns
    -------
    np.ndarray
        The integer indices of the selected variables (one-dimensional, in
        the order given by ``sel``).

    Raises
    ------
    KeyError
        If any requested name is absent from ``adata.var_names``.
    """
    if isinstance(sel, str):
        sel = [sel]

    # the indexer reports names it cannot find as -1
    idx = np.asarray(adata.var_names.get_indexer(sel), dtype=int)
    if (idx < 0).any():
        missing = [s for s, i in zip(sel, idx) if i < 0]
        raise KeyError(f"Variables not found in adata.var_names: {missing}")
    return idx