scdesigner 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of scdesigner might be problematic.
Files changed (66)
  1. scdesigner/__init__.py +0 -0
  2. scdesigner/data/__init__.py +16 -0
  3. scdesigner/data/formula.py +137 -0
  4. scdesigner/data/group.py +123 -0
  5. scdesigner/data/sparse.py +39 -0
  6. scdesigner/diagnose/__init__.py +65 -0
  7. scdesigner/diagnose/aic_bic.py +119 -0
  8. scdesigner/diagnose/plot.py +242 -0
  9. scdesigner/estimators/__init__.py +27 -0
  10. scdesigner/estimators/bernoulli.py +85 -0
  11. scdesigner/estimators/gaussian.py +121 -0
  12. scdesigner/estimators/gaussian_copula_factory.py +152 -0
  13. scdesigner/estimators/glm_factory.py +75 -0
  14. scdesigner/estimators/negbin.py +129 -0
  15. scdesigner/estimators/pnmf.py +160 -0
  16. scdesigner/estimators/poisson.py +100 -0
  17. scdesigner/estimators/zero_inflated_negbin.py +195 -0
  18. scdesigner/estimators/zero_inflated_poisson.py +85 -0
  19. scdesigner/format/__init__.py +4 -0
  20. scdesigner/format/format.py +20 -0
  21. scdesigner/format/print.py +30 -0
  22. scdesigner/minimal/__init__.py +17 -0
  23. scdesigner/minimal/bernoulli.py +61 -0
  24. scdesigner/minimal/composite.py +119 -0
  25. scdesigner/minimal/copula.py +33 -0
  26. scdesigner/minimal/formula.py +23 -0
  27. scdesigner/minimal/gaussian.py +65 -0
  28. scdesigner/minimal/kwargs.py +24 -0
  29. scdesigner/minimal/loader.py +166 -0
  30. scdesigner/minimal/marginal.py +140 -0
  31. scdesigner/minimal/negbin.py +73 -0
  32. scdesigner/minimal/positive_nonnegative_matrix_factorization.py +231 -0
  33. scdesigner/minimal/scd3.py +95 -0
  34. scdesigner/minimal/scd3_instances.py +50 -0
  35. scdesigner/minimal/simulator.py +25 -0
  36. scdesigner/minimal/standard_covariance.py +124 -0
  37. scdesigner/minimal/transform.py +145 -0
  38. scdesigner/minimal/zero_inflated_negbin.py +86 -0
  39. scdesigner/predictors/__init__.py +15 -0
  40. scdesigner/predictors/bernoulli.py +9 -0
  41. scdesigner/predictors/gaussian.py +16 -0
  42. scdesigner/predictors/negbin.py +17 -0
  43. scdesigner/predictors/poisson.py +12 -0
  44. scdesigner/predictors/zero_inflated_negbin.py +18 -0
  45. scdesigner/predictors/zero_inflated_poisson.py +18 -0
  46. scdesigner/samplers/__init__.py +23 -0
  47. scdesigner/samplers/bernoulli.py +27 -0
  48. scdesigner/samplers/gaussian.py +25 -0
  49. scdesigner/samplers/glm_factory.py +41 -0
  50. scdesigner/samplers/negbin.py +25 -0
  51. scdesigner/samplers/poisson.py +25 -0
  52. scdesigner/samplers/zero_inflated_negbin.py +40 -0
  53. scdesigner/samplers/zero_inflated_poisson.py +16 -0
  54. scdesigner/simulators/__init__.py +31 -0
  55. scdesigner/simulators/composite_regressor.py +72 -0
  56. scdesigner/simulators/glm_simulator.py +167 -0
  57. scdesigner/simulators/pnmf_regression.py +61 -0
  58. scdesigner/transform/__init__.py +7 -0
  59. scdesigner/transform/amplify.py +14 -0
  60. scdesigner/transform/mask.py +33 -0
  61. scdesigner/transform/nullify.py +25 -0
  62. scdesigner/transform/split.py +23 -0
  63. scdesigner/transform/substitute.py +14 -0
  64. scdesigner-0.0.1.dist-info/METADATA +23 -0
  65. scdesigner-0.0.1.dist-info/RECORD +66 -0
  66. scdesigner-0.0.1.dist-info/WHEEL +4 -0
scdesigner/format/print.py
@@ -0,0 +1,27 @@
+ import rich
+ import rich.table
+
+
+ def shorten_names(features, max_features=5):
+     if len(features) > max_features:
+         features = features[: max_features - 1] + ["..."] + features[-1:]
+     return f"[{', '.join(features)}]"
+
+
+ def print_simulator(margins, copula):
+     table = rich.table.Table(
+         title="[bold magenta]Simulation Plan[/bold magenta]", title_justify="left"
+     )
+     table.add_column("formula")
+     table.add_column("distribution")
+     table.add_column("features")
+
+     for features, margin in margins:
+         tup = tuple(margin.to_df().iloc[0, :]) + (shorten_names(features),)
+         table.add_row(*tup)
+
+     rich.print(table)
+     if copula is None:
+         rich.print("Marginal models without copula.")
+
+     return ""
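
To make the truncation rule concrete, here is a short standalone sketch of `shorten_names` (toy gene names; the import path is assumed from the file list above):

    from scdesigner.format.print import shorten_names  # assumed path, per the file list

    genes = [f"g{i}" for i in range(10)]
    print(shorten_names(genes))
    # first max_features - 1 names, an ellipsis, then the last name:
    # [g0, g1, g2, g3, ..., g9]
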
scdesigner/minimal/__init__.py
@@ -0,0 +1,17 @@
+ from .scd3_instances import (
+     BernoulliCopula,
+     GaussianCopula,
+     NegBinCopula,
+     ZeroInflatedNegBinCopula
+ )
+ from .composite import CompositeCopula
+ from .positive_nonnegative_matrix_factorization import PositiveNMF
+
+ __all__ = [
+     "BernoulliCopula",
+     "CompositeCopula",
+     "GaussianCopula",
+     "NegBinCopula",
+     "PositiveNMF",
+     "ZeroInflatedNegBinCopula"
+ ]
scdesigner/minimal/bernoulli.py
@@ -0,0 +1,61 @@
+ from .formula import standardize_formula
+ from .marginal import GLMPredictor, Marginal
+ from .loader import _to_numpy
+ from typing import Union, Dict, Optional
+ import torch
+ import numpy as np
+ from scipy.stats import bernoulli
+
+ class Bernoulli(Marginal):
+     """Bernoulli marginal estimator"""
+     def __init__(self, formula: Union[Dict, str]):
+         formula = standardize_formula(formula, allowed_keys=['mean'])
+         super().__init__(formula)
+
+     def setup_optimizer(
+         self,
+         optimizer_class: Optional[callable] = torch.optim.Adam,
+         **optimizer_kwargs,
+     ):
+         if self.loader is None:
+             raise RuntimeError("self.loader is not set (call setup_data first)")
+
+         link_fns = {"mean": torch.sigmoid}
+         nll = lambda batch: -self.likelihood(batch).sum()
+         self.predict = GLMPredictor(
+             n_outcomes=self.n_outcomes,
+             feature_dims=self.feature_dims,
+             link_fns=link_fns,
+             loss_fn=nll,
+             optimizer_class=optimizer_class,
+             optimizer_kwargs=optimizer_kwargs
+         )
+
+     def likelihood(self, batch):
+         """Compute the Bernoulli log-likelihood (negated and summed by the loss)"""
+         y, x = batch
+         params = self.predict(x)
+         theta = params.get("mean")
+         return y * torch.log(theta) + (1 - y) * torch.log(1 - theta)
+
+     def invert(self, u: torch.Tensor, x: Dict[str, torch.Tensor]):
+         """Invert pseudo-observations."""
+         theta, u = self._local_params(x, u)
+         y = bernoulli(theta).ppf(u)
+         return torch.from_numpy(y).float()
+
+     def uniformize(self, y: torch.Tensor, x: Dict[str, torch.Tensor], epsilon=1e-6):
+         """Return uniformized pseudo-observations for binary outcomes y given covariates x."""
+         theta, y = self._local_params(x, y)
+         u1 = bernoulli(theta).cdf(y)
+         u2 = np.where(y > 0, bernoulli(theta).cdf(y - 1), 0)
+         v = np.random.uniform(size=y.shape)
+         u = np.clip(v * u1 + (1 - v) * u2, epsilon, 1 - epsilon)
+         return torch.from_numpy(u).float()
+
+     def _local_params(self, x, y=None):
+         params = self.predict(x)
+         theta = params.get('mean')
+         if y is None:
+             return _to_numpy(theta)
+         return _to_numpy(theta, y)
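
`uniformize` implements the randomized probability integral transform for discrete data: drawing uniformly between the CDF at y and the CDF at y - 1 turns discrete observations into Uniform(0, 1) pseudo-observations, which is what the copula layer consumes. A self-contained numpy/scipy sketch of the same idea (toy data, not package code):

    import numpy as np
    from scipy.stats import bernoulli, kstest

    theta = 0.3
    rng = np.random.default_rng(0)
    y = rng.binomial(1, theta, size=100_000)

    # randomized PIT: interpolate uniformly between F(y - 1) and F(y)
    u1 = bernoulli(theta).cdf(y)
    u2 = np.where(y > 0, bernoulli(theta).cdf(y - 1), 0.0)
    v = rng.uniform(size=y.shape)
    u = v * u1 + (1 - v) * u2

    print(kstest(u, "uniform").pvalue)  # large p-value: u looks Uniform(0, 1)
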
scdesigner/minimal/composite.py
@@ -0,0 +1,119 @@
+ from .loader import obs_loader
+ from .scd3 import SCD3Simulator
+ from .standard_covariance import StandardCovariance
+ from anndata import AnnData
+ from typing import Dict, Optional, List
+ import numpy as np
+ import torch
+
+ class CompositeCopula(SCD3Simulator):
+     def __init__(self, marginals: List,
+                  copula_formula: Optional[str] = None) -> None:
+         self.marginals = marginals
+         self.copula = StandardCovariance(copula_formula)
+         self.template = None
+         self.parameters = None
+         self.merged_formula = None
+
+     def fit(
+         self,
+         adata: AnnData,
+         **kwargs):
+         """Fit the simulator"""
+         self.template = adata
+         merged_formula = {}
+
+         # fit each marginal model
+         for m in range(len(self.marginals)):
+             self.marginals[m][1].setup_data(adata[:, self.marginals[m][0]], **kwargs)
+             self.marginals[m][1].setup_optimizer(**kwargs)
+             self.marginals[m][1].fit(**kwargs)
+
+             # prepare formula for copula loader
+             f = self.marginals[m][1].formula
+             prefixed_f = {f"group{m}_{k}": v for k, v in f.items()}
+             merged_formula = merged_formula | prefixed_f
+
+         # copula simulator
+         self.merged_formula = merged_formula
+         self.copula.setup_data(adata, merged_formula, **kwargs)
+         self.copula.fit(self.merged_uniformize, **kwargs)
+         self.parameters = {
+             "marginal": [m[1].parameters for m in self.marginals],
+             "copula": self.copula.parameters
+         }
+
+     def merged_uniformize(self, y: torch.Tensor, x: Dict[str, torch.Tensor]) -> torch.Tensor:
+         """Produce a merged uniformized matrix for all marginals.
+
+         Delegates to each marginal's `uniformize` method and places the
+         result into the columns of a full matrix according to the variable
+         selection given in `self.marginals[m][0]`.
+         """
+         y_np = y.detach().cpu().numpy()
+         u = np.empty_like(y_np, dtype=float)
+
+         for m in range(len(self.marginals)):
+             sel = self.marginals[m][0]
+             ix = _var_indices(sel, self.template)
+
+             # keep only this marginal's covariates and remove the
+             # `group{m}_` prefix we used to distinguish the marginals
+             prefix = f"group{m}_"
+             cur_x = {k.removeprefix(prefix): v for k, v in x.items() if k.startswith(prefix)}
+
+             # slice the subset of y for this marginal and call its uniformize
+             y_sub = torch.from_numpy(y_np[:, ix])
+             u[:, ix] = self.marginals[m][1].uniformize(y_sub, cur_x).numpy()
+         return torch.from_numpy(u)
+
+     def predict(self, obs=None, batch_size: int = 1000, **kwargs):
+         """Predict from an obs dataframe"""
+         # prepare an internal data loader for this obs
+         if obs is None:
+             obs = self.template.obs
+         loader = obs_loader(
+             obs,
+             self.merged_formula,
+             batch_size=batch_size,
+             **kwargs
+         )
+
+         # prepare per-marginal collectors
+         n_marginals = len(self.marginals)
+         local_pred = [[] for _ in range(n_marginals)]
+
+         # for each batch, call each marginal's predict on its subset of x
+         for _, x_dict in loader:
+             for m in range(n_marginals):
+                 prefix = f"group{m}_"
+                 # keep only this marginal's keys, with the prefix stripped
+                 cur_x = {k.removeprefix(prefix): v for k, v in x_dict.items() if k.startswith(prefix)}
+                 params = self.marginals[m][1].predict(cur_x)
+                 local_pred[m].append(params)
+
+         # merge batch-wise parameter dicts for each marginal and return
+         results = []
+         for m in range(n_marginals):
+             parts = local_pred[m]
+             keys = list(parts[0].keys())
+             results.append({k: torch.cat([d[k] for d in parts]).detach().cpu().numpy() for k in keys})
+
+         return results
+
+
+ def _var_indices(sel, adata: AnnData) -> np.ndarray:
+     """Return integer indices of `sel` within `adata.var_names`.
+
+     Expected use: `sel` is a list (or tuple) of variable names (strings);
+     a single string is also accepted.
+     """
+     # if sel is a single string, wrap it so we always return a 1-D index array
+     if isinstance(sel, str):
+         sel = [sel]
+
+     idx = np.asarray(adata.var_names.get_indexer(sel), dtype=int)
+     if (idx < 0).any():
+         missing = [s for s, i in zip(sel, idx) if i < 0]
+         raise KeyError(f"Variables not found in adata.var_names: {missing}")
+     return idx
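
`merged_uniformize` depends on `_var_indices`, which resolves variable names through pandas' `Index.get_indexer`. A small toy illustration of that resolution behavior (illustrative data only):

    import numpy as np
    import pandas as pd
    from anndata import AnnData

    adata = AnnData(
        X=np.zeros((4, 3)),
        var=pd.DataFrame(index=["geneA", "geneB", "geneC"]),
    )
    print(adata.var_names.get_indexer(["geneC", "geneA"]))  # [2 0]
    print(adata.var_names.get_indexer(["geneX"]))           # [-1] -> KeyError in _var_indices
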
scdesigner/minimal/copula.py
@@ -0,0 +1,33 @@
+ from typing import Dict, Callable, Optional, Tuple
+ import torch
+ from anndata import AnnData
+ from .loader import adata_loader
+
+ class Copula:
+     def __init__(self, formula: Optional[Dict[str, str]], **kwargs):
+         self.formula = formula
+         self.loader = None
+         self.n_outcomes = None
+         self.parameters = None
+
+     def setup_data(self, adata: AnnData, marginal_formula: Dict[str, str], batch_size: int = 1024, **kwargs):
+         self.adata = adata
+         self.formula = (self.formula or {}) | marginal_formula
+         self.loader = adata_loader(adata, self.formula, batch_size=batch_size, **kwargs)
+         X_batch, _ = next(iter(self.loader))
+         self.n_outcomes = X_batch.shape[1]
+
+     def fit(self, uniformizer: Callable, **kwargs):
+         raise NotImplementedError
+
+     def pseudo_obs(self, x_dict: Dict):
+         raise NotImplementedError
+
+     def likelihood(self, uniformizer: Callable, batch: Tuple[torch.Tensor, Dict[str, torch.Tensor]]):
+         raise NotImplementedError
+
+     def num_params(self, **kwargs):
+         raise NotImplementedError
+
+     def format_parameters(self):
+         raise NotImplementedError
scdesigner/minimal/formula.py
@@ -0,0 +1,25 @@
+ from typing import Union
+ import warnings
+
+ def standardize_formula(formula: Union[str, dict], allowed_keys=None):
+     # the first element of allowed_keys is treated as the default parameter name
+     if allowed_keys is None:
+         raise ValueError("Internal error: allowed_keys must be specified")
+     formula = {allowed_keys[0]: formula} if isinstance(formula, str) else dict(formula)
+
+     formula_keys = set(formula.keys())
+     allowed_keys = set(allowed_keys)
+
+     if not formula_keys & allowed_keys:
+         raise ValueError(f"formula must have at least one of the following keys: {allowed_keys}")
+
+     if extra_keys := formula_keys - allowed_keys:
+         warnings.warn(
+             f"Invalid formulas in dictionary will not be used: {extra_keys}",
+             UserWarning,
+         )
+         for k in extra_keys:
+             formula.pop(k)
+
+     formula.update({k: '~ 1' for k in allowed_keys - formula_keys})
+     return formula
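
In short, a bare string is promoted to the default (first) key and every other allowed key falls back to an intercept-only formula. Two illustrative calls (the allowed_keys here mirror the Gaussian marginal shown below):

    standardize_formula("~ condition", allowed_keys=["mean", "sdev"])
    # {'mean': '~ condition', 'sdev': '~ 1'}

    standardize_formula({"sdev": "~ batch"}, allowed_keys=["mean", "sdev"])
    # {'sdev': '~ batch', 'mean': '~ 1'}
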
scdesigner/minimal/gaussian.py
@@ -0,0 +1,65 @@
+ from .formula import standardize_formula
+ from .marginal import GLMPredictor, Marginal
+ from .loader import _to_numpy
+ from typing import Union, Dict, Optional
+ import torch
+ import numpy as np
+ from scipy.stats import norm
+
+ class Gaussian(Marginal):
+     """Gaussian marginal estimator"""
+     def __init__(self, formula: Union[Dict, str]):
+         formula = standardize_formula(formula, allowed_keys=['mean', 'sdev'])
+         super().__init__(formula)
+
+     def setup_optimizer(
+         self,
+         optimizer_class: Optional[callable] = torch.optim.Adam,
+         **optimizer_kwargs,
+     ):
+         if self.loader is None:
+             raise RuntimeError("self.loader is not set (call setup_data first)")
+
+         nll = lambda batch: -self.likelihood(batch).sum()
+         link_fns = {"mean": lambda x: x}  # identity link for the mean; `sdev` keeps the default exp link
+         self.predict = GLMPredictor(
+             n_outcomes=self.n_outcomes,
+             feature_dims=self.feature_dims,
+             link_fns=link_fns,
+             loss_fn=nll,
+             optimizer_class=optimizer_class,
+             optimizer_kwargs=optimizer_kwargs
+         )
+
+     def likelihood(self, batch):
+         """Compute the Gaussian log-likelihood (negated and summed by the loss)"""
+         y, x = batch
+         params = self.predict(x)
+         mu = params.get("mean")
+         sigma = params.get("sdev")
+
+         # log-likelihood for Gaussian
+         log_likelihood = -0.5 * (torch.log(2 * torch.pi * sigma ** 2) + ((y - mu) ** 2) / (sigma ** 2))
+         return log_likelihood
+
+     def invert(self, u: torch.Tensor, x: Dict[str, torch.Tensor]):
+         """Invert pseudo-observations."""
+         mu, sdev, u = self._local_params(x, u)
+         y = norm(loc=mu, scale=sdev).ppf(u)
+         return torch.from_numpy(y).float()
+
+     def uniformize(self, y: torch.Tensor, x: Dict[str, torch.Tensor], epsilon=1e-6):
+         """Return uniformized pseudo-observations for observations y given covariates x."""
+         # CDF values using scipy's parameterization
+         mu, sdev, y = self._local_params(x, y)
+         u = norm.cdf(y, loc=mu, scale=sdev)
+         u = np.clip(u, epsilon, 1 - epsilon)
+         return torch.from_numpy(u).float()
+
+     def _local_params(self, x, y=None):
+         params = self.predict(x)
+         mu = params.get('mean')
+         sdev = params.get('sdev')
+         if y is None:
+             return _to_numpy(mu, sdev)
+         return _to_numpy(mu, sdev, y)
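
For a continuous marginal like this one, `uniformize` and `invert` are inverses up to the epsilon clipping, which is the property the copula layer relies on. A standalone scipy sketch of that round trip (toy numbers, not package code):

    import numpy as np
    from scipy.stats import norm

    mu, sdev = 2.0, 0.5
    rng = np.random.default_rng(1)
    y = rng.normal(mu, sdev, size=5)

    u = np.clip(norm.cdf(y, loc=mu, scale=sdev), 1e-6, 1 - 1e-6)  # uniformize
    y_back = norm(loc=mu, scale=sdev).ppf(u)                      # invert

    print(np.allclose(y, y_back))  # True
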
scdesigner/minimal/kwargs.py
@@ -0,0 +1,24 @@
+ def _filter_kwargs(kwargs: dict, allowed: set) -> dict:
+     """Return a new dict containing only keys from kwargs that are in allowed."""
+     return {k: v for k, v in kwargs.items() if k in allowed}
+
+
+ # allowed kwargs for common components
+ DEFAULT_ALLOWED_KWARGS = {
+     'trainer': {
+         'max_epochs', 'gpus', 'devices', 'accelerator', 'precision',
+         'logger', 'callbacks', 'strategy', 'num_nodes', 'limit_train_batches',
+         'log_every_n_steps', 'accumulate_grad_batches'
+     },
+     'data': {
+         'chunk_size', 'batch_size', 'shuffle', 'num_workers'
+     },
+     'optimizer': {
+         'lr', 'learning_rate', 'momentum', 'weight_decay', 'eps', 'betas',
+         'amsgrad', 'dampening', 'nesterov', 'alpha',
+         'T_max', 'eta_min', 'step_size', 'gamma', 'milestones', 'last_epoch',
+         'verbose', 'patience', 'threshold', 'cooldown',
+         'optimizer_class', 'optimizer', 'scheduler_class', 'scheduler',
+         'monitor', 'interval', 'frequency'
+     }
+ }
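
These allow-lists let a single `**kwargs` bag flow from `fit` down to the trainer, the data loader, and the optimizer, with each consumer keeping only the keys it recognizes. For example:

    kwargs = {"max_epochs": 10, "lr": 1e-2, "batch_size": 256, "unknown": 1}

    _filter_kwargs(kwargs, DEFAULT_ALLOWED_KWARGS["trainer"])    # {'max_epochs': 10}
    _filter_kwargs(kwargs, DEFAULT_ALLOWED_KWARGS["optimizer"])  # {'lr': 0.01}
    _filter_kwargs(kwargs, DEFAULT_ALLOWED_KWARGS["data"])       # {'batch_size': 256}
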
scdesigner/minimal/loader.py
@@ -0,0 +1,168 @@
+ from .kwargs import DEFAULT_ALLOWED_KWARGS, _filter_kwargs
+ from anndata import AnnData
+ from formulaic import model_matrix
+ from torch.utils.data import Dataset, DataLoader
+ from typing import Dict, Optional
+ import numpy as np
+ import pandas as pd
+ import torch
+
+ class AnnDataDataset(Dataset):
+     """Simple PyTorch Dataset for AnnData objects.
+
+     Supports optional chunked loading for backed AnnData objects. When
+     `chunk_size` is provided, the dataset will load contiguous slices
+     of rows (of size `chunk_size`) into memory once and serve individual
+     rows from that cached chunk. This avoids calling `to_memory()` on
+     a per-row basis, which is expensive for large backed files.
+     """
+     def __init__(self, adata: AnnData, formula: Dict[str, str], chunk_size: int):
+         self.adata = adata
+         self.formula = formula
+         self.chunk_size = chunk_size
+
+         # keep track of covariate-related information
+         self.obs_levels = categories(self.adata.obs)
+         self.obs_matrices = {}
+         self.predictor_names = None
+
+         # internal cache for the currently loaded chunk
+         self._chunk: AnnData | None = None
+         self._chunk_start = 0
+
+     def __len__(self):
+         return len(self.adata)
+
+     def __getitem__(self, idx):
+         """Return (X, obs) for the given index.
+
+         If `chunk_size` was specified the dataset will load a chunk
+         containing `idx` into memory (if not already cached) and
+         index into that chunk.
+         """
+         self._ensure_chunk_loaded(idx)
+         local_idx = idx - self._chunk_start
+         adata_slice = self._chunk[local_idx]
+
+         # get X data, accounting for potential sparse matrices
+         X = adata_slice.X
+         if hasattr(X, 'toarray'):
+             X = X.toarray()
+
+         # get obs data
+         obs_dict = {}
+         for key in self.formula.keys():
+             mat = self.obs_matrices.get(key)
+             obs_dict[key] = to_tensor(mat.values[local_idx: local_idx + 1])
+         return to_tensor(X), obs_dict
+
+     def _ensure_chunk_loaded(self, idx: int) -> None:
+         """Load the chunk that contains `idx` into the internal cache."""
+         start = (idx // self.chunk_size) * self.chunk_size
+         end = min(start + self.chunk_size, len(self.adata))
+
+         if (self._chunk is None) or not (self._chunk_start <= idx < self._chunk_start + len(self._chunk)):
+             # load the next chunk into memory
+             chunk = self.adata[start:end]
+             if getattr(chunk, 'isbacked', False):
+                 chunk = chunk.to_memory()
+             self._chunk = chunk
+             self._chunk_start = start
+
+             # Compute model matrices for this chunk's `obs` so we don't need
+             # to keep the full obs data model matrices in memory.
+             obs_coded_chunk = code_levels(self._chunk.obs.copy(), self.obs_levels)
+             self.obs_matrices = {}
+             for key, f in self.formula.items():
+                 self.obs_matrices[key] = model_matrix(f, obs_coded_chunk)
+
+             # Capture predictor (column) names from the model matrices once.
+             if self.predictor_names is None:
+                 self.predictor_names = {k: list(v.columns) for k, v in self.obs_matrices.items()}
+
+
+ def adata_loader(adata: AnnData,
+                  formula: Dict[str, str],
+                  chunk_size: Optional[int] = None,
+                  batch_size: int = 1024,
+                  shuffle: bool = False,
+                  num_workers: int = 0,
+                  **kwargs) -> DataLoader:
+     """
+     Create a DataLoader from AnnData that returns batches of (X, obs).
+     """
+     data_kwargs = _filter_kwargs(kwargs, DEFAULT_ALLOWED_KWARGS['data'])
+     if chunk_size is None:
+         if getattr(adata, 'isbacked', False):
+             chunk_size = 5000
+         else:
+             chunk_size = len(adata)
+
+     dataset = AnnDataDataset(adata, formula, chunk_size)
+     return DataLoader(
+         dataset,
+         batch_size=batch_size,
+         shuffle=shuffle,
+         num_workers=num_workers,
+         collate_fn=dict_collate_fn,
+         **data_kwargs
+     )
+
+ def obs_loader(obs: pd.DataFrame, marginal_formula, **kwargs):
+     # wrap obs in a dummy AnnData (one zero outcome column) so we can reuse adata_loader
+     adata = AnnData(X=np.zeros((len(obs), 1)), obs=obs)
+     return adata_loader(
+         adata,
+         marginal_formula,
+         **kwargs
+     )
+
+ ################################################################################
+ ## Helper functions
+ ################################################################################
+
+ def dict_collate_fn(batch):
+     """
+     Custom collate function for handling dictionary obs tensors.
+     """
+     X_batch = torch.stack([item[0] for item in batch])
+     obs_batch = [item[1] for item in batch]
+
+     obs_dict = {}
+     for key in obs_batch[0].keys():
+         obs_dict[key] = torch.stack([obs[key] for obs in obs_batch])
+     return X_batch, obs_dict
+
+ def to_tensor(X):
+     # Preserve a length-1 vector when X is a single 1-column row (e.g. the
+     # dummy outcome in `obs_loader`); otherwise drop the singleton row dim.
+     t = torch.tensor(X, dtype=torch.float32)
+     if t.dim() == 2 and t.size(1) == 1:
+         if t.size(0) == 1:
+             return t.view(1)
+         return t
+     return t.squeeze()
+
+ def categories(obs):
+     levels = {}
+     for k in obs.columns:
+         obs_type = str(obs[k].dtype)
+         if obs_type in ["category", "object"]:
+             levels[k] = obs[k].unique()
+     return levels
+
+
+ def code_levels(obs, categories):
+     # Recode every tracked column with the full set of levels observed on the
+     # complete obs table, so dummy coding stays aligned across chunks.
+     for k, levels in categories.items():
+         obs[k] = obs[k].astype(pd.CategoricalDtype(levels))
+     return obs
+
+ ###############################################################################
+ ## Misc. Helper functions
+ ###############################################################################
+
+ def _to_numpy(*tensors):
+     """Convenience helper: detach, move to CPU, and convert tensors to numpy arrays."""
+     return tuple(t.detach().cpu().numpy() for t in tensors)
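
A tiny in-memory AnnData is enough to see the (X, obs) batch structure the loader produces. A hedged usage sketch (toy data; `~ condition` is ordinary Wilkinson formula syntax handled by formulaic):

    import numpy as np
    import pandas as pd
    from anndata import AnnData

    obs = pd.DataFrame({"condition": ["a", "b", "a", "b"]})
    adata = AnnData(X=np.random.poisson(3.0, size=(4, 2)).astype(float), obs=obs)

    loader = adata_loader(adata, {"mean": "~ condition"}, batch_size=2)
    X_batch, obs_dict = next(iter(loader))
    print(X_batch.shape)           # torch.Size([2, 2])
    print(obs_dict["mean"].shape)  # torch.Size([2, 2]): Intercept + condition[T.b]
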
scdesigner/minimal/marginal.py
@@ -0,0 +1,139 @@
+ from .kwargs import DEFAULT_ALLOWED_KWARGS, _filter_kwargs
+ from .loader import adata_loader
+ from anndata import AnnData
+ from typing import Union, Dict, Optional, Tuple
+ import pandas as pd
+ import pytorch_lightning as pl
+ import torch
+ import torch.nn as nn
+
+
+ class Marginal:
+     def __init__(self, formula: Union[Dict, str]):
+         self.formula = formula
+         self.feature_dims = None
+         self.loader = None
+         self.n_outcomes = None
+         self.predict = None
+         self.predictor_names = None
+         self.parameters = None
+
+     def setup_data(self, adata: AnnData, batch_size: int = 1024, **kwargs):
+         """Set up the dataloader for the AnnData object."""
+         # keep a reference to the AnnData for later use (e.g., var_names)
+         self.adata = adata
+         self.loader = adata_loader(adata, self.formula, batch_size=batch_size, **kwargs)
+         X_batch, obs_batch = next(iter(self.loader))
+         self.n_outcomes = X_batch.shape[1]
+         self.feature_dims = {k: v.shape[1] for k, v in obs_batch.items()}
+         self.predictor_names = self.loader.dataset.predictor_names
+
+     def fit(self, **kwargs):
+         """Fit the marginal predictor"""
+         if self.predict is None:
+             self.setup_optimizer(**kwargs)
+         trainer_kwargs = _filter_kwargs(kwargs, DEFAULT_ALLOWED_KWARGS['trainer'])
+         trainer = pl.Trainer(**trainer_kwargs)
+         trainer.fit(self.predict, train_dataloaders=self.loader)
+         self.parameters = self.format_parameters()
+
+     def setup_optimizer(self, **kwargs):
+         raise NotImplementedError
+
+     def likelihood(self, batch: Tuple[torch.Tensor, Dict[str, torch.Tensor]]):
+         """Compute the log-likelihood for a batch (negated by the training loss)."""
+         raise NotImplementedError
+
+     def invert(self, u: torch.Tensor, x: Dict[str, torch.Tensor]):
+         """Invert pseudo-observations."""
+         raise NotImplementedError
+
+     def uniformize(self, y: torch.Tensor, x: Dict[str, torch.Tensor]):
+         """Uniformize using the learned CDF."""
+         raise NotImplementedError
+
+     def format_parameters(self):
+         """Convert fitted coefficient tensors into pandas DataFrames.
+
+         Returns:
+             dict: mapping from parameter name -> pandas.DataFrame with rows
+             corresponding to predictor column names (from
+             `self.predictor_names[param]`) and columns corresponding to
+             `self.adata.var_names` (gene names). The values are moved to
+             CPU and converted to numpy floats.
+         """
+         var_names = list(self.adata.var_names)
+
+         dfs = {}
+         for param, tensor in self.predict.coefs.items():
+             coef_np = tensor.detach().cpu().numpy()
+             row_names = list(self.predictor_names[param])
+             dfs[param] = pd.DataFrame(coef_np, index=row_names, columns=var_names)
+         return dfs
+
+     def num_params(self):
+         """Return the number of trainable parameters."""
+         if self.predict is None:
+             return 0
+         return sum(p.numel() for p in self.predict.parameters() if p.requires_grad)
+
+
+ class GLMPredictor(pl.LightningModule):
+     """GLM-style predictor with arbitrary named parameters.
+
+     Args:
+         n_outcomes: number of model outputs (e.g. genes)
+         feature_dims: mapping from param name -> number of covariate features
+         link_fns: optional mapping from param name -> callable applied to the linear predictor
+
+     The module will create one coefficient matrix per named parameter with shape
+     (n_features_for_param, n_outcomes) and expose them as Parameters under
+     `self.coefs[param_name]`.
+     """
+     def __init__(
+         self,
+         n_outcomes: int,
+         feature_dims: Dict[str, int],
+         link_fns: Optional[Dict[str, callable]] = None,
+         loss_fn: Optional[callable] = None,
+         optimizer_class: Optional[callable] = torch.optim.Adam,
+         optimizer_kwargs: Optional[Dict] = None,
+     ):
+         super().__init__()
+         self.n_outcomes = int(n_outcomes)
+         self.feature_dims = dict(feature_dims)
+         self.param_names = list(self.feature_dims.keys())
+
+         # create default link functions and parameter matrices
+         self.link_fns = link_fns or {k: torch.exp for k in self.param_names}
+         self.coefs = nn.ParameterDict()
+         for key, dim in self.feature_dims.items():
+             self.coefs[key] = nn.Parameter(torch.zeros(dim, self.n_outcomes))
+
+         # optimization parameters
+         self.reset_parameters()
+         self.loss_fn = loss_fn
+         self.optimizer_class = optimizer_class
+         self.optimizer_kwargs = optimizer_kwargs or {}
+
+     def reset_parameters(self):
+         for p in self.coefs.values():
+             nn.init.normal_(p, mean=0.0, std=1e-2)
+
+     def forward(self, obs_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
+         out = {}
+         for name in self.param_names:
+             x_beta = obs_dict[name] @ self.coefs[name]
+             link = self.link_fns.get(name, torch.exp)
+             out[name] = link(x_beta)
+         return out
+
+     def training_step(self, batch, batch_idx):
+         loss = self.loss_fn(batch)
+         self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True)
+         return loss
+
+     def configure_optimizers(self):
+         # use the kwargs captured at construction time; Lightning calls this with no arguments
+         optimizer_kwargs = _filter_kwargs(self.optimizer_kwargs, DEFAULT_ALLOWED_KWARGS['optimizer'])
+         return self.optimizer_class(self.parameters(), **optimizer_kwargs)
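
The forward pass is one matrix product and one link function per named parameter. A minimal standalone check against the class as defined above (toy dimensions):

    import torch

    predictor = GLMPredictor(
        n_outcomes=3,
        feature_dims={"mean": 2, "sdev": 2},
        link_fns={"mean": lambda x: x},  # sdev falls back to the default torch.exp
    )
    obs_dict = {"mean": torch.randn(5, 2), "sdev": torch.randn(5, 2)}
    out = predictor(obs_dict)
    print(out["mean"].shape)              # torch.Size([5, 3])
    print(bool((out["sdev"] > 0).all()))  # True: exp link keeps sdev positive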