scdesigner 0.0.5__py3-none-any.whl → 0.0.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (84)
  1. scdesigner/base/__init__.py +8 -0
  2. scdesigner/base/copula.py +416 -0
  3. scdesigner/base/marginal.py +391 -0
  4. scdesigner/base/simulator.py +59 -0
  5. scdesigner/copulas/__init__.py +8 -0
  6. scdesigner/copulas/standard_copula.py +645 -0
  7. scdesigner/datasets/__init__.py +5 -0
  8. scdesigner/datasets/pancreas.py +39 -0
  9. scdesigner/distributions/__init__.py +19 -0
  10. scdesigner/{minimal → distributions}/bernoulli.py +42 -14
  11. scdesigner/distributions/gaussian.py +114 -0
  12. scdesigner/distributions/negbin.py +121 -0
  13. scdesigner/distributions/negbin_irls.py +72 -0
  14. scdesigner/distributions/negbin_irls_funs.py +456 -0
  15. scdesigner/distributions/poisson.py +88 -0
  16. scdesigner/{minimal → distributions}/zero_inflated_negbin.py +39 -10
  17. scdesigner/distributions/zero_inflated_poisson.py +103 -0
  18. scdesigner/simulators/__init__.py +24 -28
  19. scdesigner/simulators/composite.py +239 -0
  20. scdesigner/simulators/positive_nonnegative_matrix_factorization.py +477 -0
  21. scdesigner/simulators/scd3.py +486 -0
  22. scdesigner/transform/__init__.py +8 -6
  23. scdesigner/{minimal → transform}/transform.py +1 -1
  24. scdesigner/{minimal → utils}/kwargs.py +4 -1
  25. {scdesigner-0.0.5.dist-info → scdesigner-0.0.10.dist-info}/METADATA +1 -1
  26. scdesigner-0.0.10.dist-info/RECORD +28 -0
  27. {scdesigner-0.0.5.dist-info → scdesigner-0.0.10.dist-info}/WHEEL +1 -1
  28. scdesigner/data/__init__.py +0 -16
  29. scdesigner/data/formula.py +0 -137
  30. scdesigner/data/group.py +0 -123
  31. scdesigner/data/sparse.py +0 -39
  32. scdesigner/diagnose/__init__.py +0 -65
  33. scdesigner/diagnose/aic_bic.py +0 -119
  34. scdesigner/diagnose/plot.py +0 -242
  35. scdesigner/estimators/__init__.py +0 -32
  36. scdesigner/estimators/bernoulli.py +0 -85
  37. scdesigner/estimators/gaussian.py +0 -121
  38. scdesigner/estimators/gaussian_copula_factory.py +0 -367
  39. scdesigner/estimators/glm_factory.py +0 -75
  40. scdesigner/estimators/negbin.py +0 -153
  41. scdesigner/estimators/pnmf.py +0 -160
  42. scdesigner/estimators/poisson.py +0 -124
  43. scdesigner/estimators/zero_inflated_negbin.py +0 -195
  44. scdesigner/estimators/zero_inflated_poisson.py +0 -85
  45. scdesigner/format/__init__.py +0 -4
  46. scdesigner/format/format.py +0 -20
  47. scdesigner/format/print.py +0 -30
  48. scdesigner/minimal/__init__.py +0 -17
  49. scdesigner/minimal/composite.py +0 -119
  50. scdesigner/minimal/copula.py +0 -205
  51. scdesigner/minimal/formula.py +0 -23
  52. scdesigner/minimal/gaussian.py +0 -65
  53. scdesigner/minimal/loader.py +0 -211
  54. scdesigner/minimal/marginal.py +0 -154
  55. scdesigner/minimal/negbin.py +0 -73
  56. scdesigner/minimal/positive_nonnegative_matrix_factorization.py +0 -231
  57. scdesigner/minimal/scd3.py +0 -96
  58. scdesigner/minimal/scd3_instances.py +0 -50
  59. scdesigner/minimal/simulator.py +0 -25
  60. scdesigner/minimal/standard_copula.py +0 -383
  61. scdesigner/predictors/__init__.py +0 -15
  62. scdesigner/predictors/bernoulli.py +0 -9
  63. scdesigner/predictors/gaussian.py +0 -16
  64. scdesigner/predictors/negbin.py +0 -17
  65. scdesigner/predictors/poisson.py +0 -12
  66. scdesigner/predictors/zero_inflated_negbin.py +0 -18
  67. scdesigner/predictors/zero_inflated_poisson.py +0 -18
  68. scdesigner/samplers/__init__.py +0 -23
  69. scdesigner/samplers/bernoulli.py +0 -27
  70. scdesigner/samplers/gaussian.py +0 -25
  71. scdesigner/samplers/glm_factory.py +0 -103
  72. scdesigner/samplers/negbin.py +0 -25
  73. scdesigner/samplers/poisson.py +0 -25
  74. scdesigner/samplers/zero_inflated_negbin.py +0 -40
  75. scdesigner/samplers/zero_inflated_poisson.py +0 -16
  76. scdesigner/simulators/composite_regressor.py +0 -72
  77. scdesigner/simulators/glm_simulator.py +0 -167
  78. scdesigner/simulators/pnmf_regression.py +0 -61
  79. scdesigner/transform/amplify.py +0 -14
  80. scdesigner/transform/mask.py +0 -33
  81. scdesigner/transform/nullify.py +0 -25
  82. scdesigner/transform/split.py +0 -23
  83. scdesigner/transform/substitute.py +0 -14
  84. scdesigner-0.0.5.dist-info/RECORD +0 -66
@@ -1,119 +0,0 @@
- from .loader import obs_loader
- from .scd3 import SCD3Simulator
- from .standard_copula import StandardCopula
- from anndata import AnnData
- from typing import Dict, Optional, List
- import numpy as np
- import torch
-
- class CompositeCopula(SCD3Simulator):
-     def __init__(self, marginals: List,
-                  copula_formula: Optional[str] = None) -> None:
-         self.marginals = marginals
-         self.copula = StandardCopula(copula_formula)
-         self.template = None
-         self.parameters = None
-         self.merged_formula = None
-
-     def fit(
-         self,
-         adata: AnnData,
-         **kwargs):
-         """Fit the simulator."""
-         self.template = adata
-         merged_formula = {}
-
-         # fit each marginal model
-         for m in range(len(self.marginals)):
-             self.marginals[m][1].setup_data(adata[:, self.marginals[m][0]], **kwargs)
-             self.marginals[m][1].setup_optimizer(**kwargs)
-             self.marginals[m][1].fit(**kwargs)
-
-             # prepare formula for copula loader
-             f = self.marginals[m][1].formula
-             prefixed_f = {f"group{m}_{k}": v for k, v in f.items()}
-             merged_formula = merged_formula | prefixed_f
-
-         # copula simulator
-         self.merged_formula = merged_formula
-         self.copula.setup_data(adata, merged_formula, **kwargs)
-         self.copula.fit(self.merged_uniformize, **kwargs)
-         self.parameters = {
-             "marginal": [m[1].parameters for m in self.marginals],
-             "copula": self.copula.parameters
-         }
-
-     def merged_uniformize(self, y: torch.Tensor, x: Dict[str, torch.Tensor]) -> torch.Tensor:
-         """Produce a merged uniformized matrix for all marginals.
-
-         Delegates to each marginal's `uniformize` method and places the
-         result into the columns of a full matrix according to the variable
-         selection given in `self.marginals[m][0]`.
-         """
-         y_np = y.detach().cpu().numpy()
-         u = np.empty_like(y_np, dtype=float)
-
-         for m in range(len(self.marginals)):
-             sel = self.marginals[m][0]
-             ix = _var_indices(sel, self.template)
-
-             # remove the `group{m}_` prefix we used to distinguish the marginals
-             prefix = f"group{m}_"
-             cur_x = {k.removeprefix(prefix): v for k, v in x.items()}
-
-             # slice the subset of y for this marginal and call its uniformize
-             y_sub = torch.from_numpy(y_np[:, ix])
-             u[:, ix] = self.marginals[m][1].uniformize(y_sub, cur_x)
-         return torch.from_numpy(u)
-
-     def predict(self, obs=None, batch_size: int = 1000, **kwargs):
-         """Predict from an obs dataframe."""
-         # prepare an internal data loader for this obs
-         if obs is None:
-             obs = self.template.obs
-         loader = obs_loader(
-             obs,
-             self.merged_formula,
-             batch_size=batch_size,
-             **kwargs
-         )
-
-         # prepare per-marginal collectors
-         n_marginals = len(self.marginals)
-         local_pred = [[] for _ in range(n_marginals)]
-
-         # for each batch, call each marginal's predict on its subset of x
-         for _, x_dict in loader:
-             for m in range(n_marginals):
-                 prefix = f"group{m}_"
-                 # build cur_x where prefixed keys are unprefixed for the marginal
-                 cur_x = {k.removeprefix(prefix): v for k, v in x_dict.items()}
-                 params = self.marginals[m][1].predict(cur_x)
-                 local_pred[m].append(params)
-
-         # merge batch-wise parameter dicts for each marginal and return
-         results = []
-         for m in range(n_marginals):
-             parts = local_pred[m]
-             keys = list(parts[0].keys())
-             results.append({k: torch.cat([d[k] for d in parts]).detach().cpu().numpy() for k in keys})
-
-         return results
-
-
- def _var_indices(sel, adata: AnnData) -> np.ndarray:
-     """Return integer indices of `sel` within `adata.var_names`.
-
-     Expected use: `sel` is a list (or tuple) of variable names (strings).
-     """
-     # if sel is a single string, wrap it in a list so we return a 1-D index array
-     if isinstance(sel, str):
-         sel = [sel]
-
-     idx = np.asarray(adata.var_names.get_indexer(sel), dtype=int)
-     if (idx < 0).any():
-         missing = [s for s, i in zip(sel, idx) if i < 0]
-         raise KeyError(f"Variables not found in adata.var_names: {missing}")
-     return idx
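The deleted `CompositeCopula` hinges on a key-prefixing convention: each marginal's formula keys are prefixed with `group{m}_` before being merged into one dictionary for the copula loader, and the prefix is stripped again whenever covariates are handed back to marginal `m`. A minimal, self-contained sketch of that convention (the formulas and keys below are illustrative, not taken from the package):

```python
# Two hypothetical marginal formulas, as a composite model might collect them.
marginal_formulas = [
    {"mean": "~ condition"},         # marginal 0
    {"mean": "~ 1", "sdev": "~ 1"},  # marginal 1
]

# Merge with group prefixes, mirroring CompositeCopula.fit.
merged = {}
for m, f in enumerate(marginal_formulas):
    merged |= {f"group{m}_{k}": v for k, v in f.items()}
print(merged)
# {'group0_mean': '~ condition', 'group1_mean': '~ 1', 'group1_sdev': '~ 1'}

# Recover marginal 1's view of a merged covariate dict: its own keys lose the
# prefix, while other groups' keys pass through unchanged and are ignored.
x = {k: None for k in merged}  # stand-in for covariate tensors
cur_x = {k.removeprefix("group1_"): v for k, v in x.items()}
print(sorted(cur_x))
# ['group0_mean', 'mean', 'sdev']
```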
@@ -1,205 +0,0 @@
- from typing import Callable, Dict, Optional, Tuple, Union
- import torch
- from anndata import AnnData
- from .loader import adata_loader
- from abc import ABC, abstractmethod
- import numpy as np
- import pandas as pd
-
- class Copula(ABC):
-     def __init__(self, formula: str, **kwargs):
-         self.formula = formula
-         self.loader = None
-         self.n_outcomes = None
-         self.parameters = None  # Should be a dictionary of CovarianceStructure objects
-
-     def setup_data(self, adata: AnnData, marginal_formula: Dict[str, str], batch_size: int = 1024, **kwargs):
-         self.adata = adata
-         self.formula = self.formula | marginal_formula
-         self.loader = adata_loader(adata, self.formula, batch_size=batch_size, **kwargs)
-         X_batch, _ = next(iter(self.loader))
-         self.n_outcomes = X_batch.shape[1]
-
-     def decorrelate(self, row_pattern: str, col_pattern: str, group: Union[str, list, None] = None):
-         """Decorrelate the covariance matrix for the given row and column patterns.
-
-         Args:
-             row_pattern (str): The regex pattern for the row names to match.
-             col_pattern (str): The regex pattern for the column names to match.
-             group (Union[str, list, None]): The group or groups to apply the transformation to. If None, the transformation is applied to all groups.
-         """
-         if group is None:
-             for g in self.groups:
-                 self.parameters[g].decorrelate(row_pattern, col_pattern)
-         elif isinstance(group, str):
-             self.parameters[group].decorrelate(row_pattern, col_pattern)
-         else:
-             for g in group:
-                 self.parameters[g].decorrelate(row_pattern, col_pattern)
-
-     def correlate(self, factor: float, row_pattern: str, col_pattern: str, group: Union[str, list, None] = None):
-         """Multiply selected off-diagonal entries by factor.
-
-         Args:
-             factor (float): The factor to multiply the off-diagonal entries by.
-             row_pattern (str): The regex pattern for the row names to match.
-             col_pattern (str): The regex pattern for the column names to match.
-             group (Union[str, list, None]): The group or groups to apply the transformation to. If None, the transformation is applied to all groups.
-         """
-         if group is None:
-             for g in self.groups:
-                 self.parameters[g].correlate(row_pattern, col_pattern, factor)
-         elif isinstance(group, str):
-             self.parameters[group].correlate(row_pattern, col_pattern, factor)
-         else:
-             for g in group:
-                 self.parameters[g].correlate(row_pattern, col_pattern, factor)
-
-     @abstractmethod
-     def fit(self, uniformizer: Callable, **kwargs):
-         raise NotImplementedError
-
-     @abstractmethod
-     def pseudo_obs(self, x_dict: Dict):
-         raise NotImplementedError
-
-     @abstractmethod
-     def likelihood(self, uniformizer: Callable, batch: Tuple[torch.Tensor, Dict[str, torch.Tensor]]):
-         raise NotImplementedError
-
-     @abstractmethod
-     def num_params(self, **kwargs):
-         raise NotImplementedError
-
-     # @abstractmethod
-     # def format_parameters(self):
-     #     raise NotImplementedError
-
- class CovarianceStructure:
-     """
-     Efficient storage for covariance matrices in copula-based gene expression modeling.
-
-     This class provides memory-efficient storage for covariance information by storing
-     either a full covariance matrix or a block matrix with diagonal variances for
-     the remaining genes. This enables fast copula estimation and sampling for large
-     gene expression datasets.
-
-     Attributes
-     ----------
-     cov : pd.DataFrame
-         Covariance matrix for modeled genes with gene names as index/columns
-     modeled_indices : np.ndarray
-         Indices of modeled genes in the original ordering
-     remaining_var : pd.Series or None
-         Diagonal variances for remaining genes, None if the full matrix is stored
-     remaining_indices : np.ndarray or None
-         Indices of remaining genes in the original ordering
-     num_modeled_genes : int
-         Number of modeled genes
-     num_remaining_genes : int
-         Number of remaining genes (0 if the full matrix is stored)
-     total_genes : int
-         Total number of genes
-     """
-
-     def __init__(self, cov: np.ndarray,
-                  modeled_names: pd.Index,
-                  modeled_indices: Optional[np.ndarray] = None,
-                  remaining_var: Optional[np.ndarray] = None,
-                  remaining_indices: Optional[np.ndarray] = None,
-                  remaining_names: Optional[pd.Index] = None):
-         """Initialize a CovarianceStructure object.
-
-         Args:
-             cov (np.ndarray): Covariance matrix for modeled genes, shape (n_modeled_genes, n_modeled_genes)
-             modeled_names (pd.Index): Gene names for the modeled genes
-             modeled_indices (Optional[np.ndarray], optional): Indices of modeled genes in the original ordering. Defaults to sequential indices.
-             remaining_var (Optional[np.ndarray], optional): Diagonal variances for remaining genes, shape (n_remaining_genes,)
-             remaining_indices (Optional[np.ndarray], optional): Indices of remaining genes in the original ordering
-             remaining_names (Optional[pd.Index], optional): Gene names for the remaining genes
-         """
-         self.cov = pd.DataFrame(cov, index=modeled_names, columns=modeled_names)
-
-         if modeled_indices is not None:
-             self.modeled_indices = modeled_indices
-         else:
-             self.modeled_indices = np.arange(len(modeled_names))
-
-         if remaining_var is not None:
-             self.remaining_var = pd.Series(remaining_var, index=remaining_names)
-         else:
-             self.remaining_var = None
-
-         self.remaining_indices = remaining_indices
-         self.num_modeled_genes = len(modeled_names)
-         self.num_remaining_genes = len(remaining_indices) if remaining_indices is not None else 0
-         self.total_genes = self.num_modeled_genes + self.num_remaining_genes
-
-     def __repr__(self):
-         if self.remaining_var is None:
-             return self.cov.__repr__()
-         else:
-             return (f"CovarianceStructure(modeled_genes={self.num_modeled_genes}, "
-                     f"total_genes={self.total_genes})")
-
-     def _repr_html_(self):
-         """Jupyter Notebook display."""
-         if self.remaining_var is None:
-             return self.cov._repr_html_()
-         else:
-             html = f"<b>CovarianceStructure:</b> {self.num_modeled_genes} modeled genes, {self.total_genes} total<br>"
-             html += "<h4>Modeled Covariance Matrix</h4>" + self.cov._repr_html_()
-             html += "<h4>Remaining Gene Variances</h4>" + self.remaining_var.to_frame("variance").T._repr_html_()
-             return html
-
-     def decorrelate(self, row_pattern: str, col_pattern: str):
-         """Decorrelate the covariance matrix for the given row and column patterns."""
-         from .transform import data_frame_mask
-         m1 = data_frame_mask(self.cov, ".", col_pattern)
-         m2 = data_frame_mask(self.cov, row_pattern, ".")
-         mask = (m1 | m2)
-         np.fill_diagonal(mask, False)
-         self.cov.values[mask] = 0
-
-     def correlate(self, row_pattern: str, col_pattern: str, factor: float):
-         """Multiply selected off-diagonal entries by factor.
-
-         Args:
-             row_pattern (str): The regex pattern for the row names to match.
-             col_pattern (str): The regex pattern for the column names to match.
-             factor (float): The factor to multiply the off-diagonal entries by.
-         """
-         from .transform import data_frame_mask
-         m1 = data_frame_mask(self.cov, ".", col_pattern)
-         m2 = data_frame_mask(self.cov, row_pattern, ".")
-         mask = (m1 | m2)
-         np.fill_diagonal(mask, False)
-         self.cov.values[mask] = self.cov.values[mask] * factor
-
-     @property
-     def shape(self):
-         return (self.total_genes, self.total_genes)
-
-     def to_full_matrix(self):
-         """Convert to a full covariance matrix for compatibility/debugging.
-
-         Returns
-         -------
-         np.ndarray
-             Full covariance matrix with shape (total_genes, total_genes)
-         """
-         if self.remaining_var is None:
-             return self.cov.values
-         else:
-             full_cov = np.zeros((self.total_genes, self.total_genes))
-
-             # Fill in the top-k block
-             ix_modeled = np.ix_(self.modeled_indices, self.modeled_indices)
-             full_cov[ix_modeled] = self.cov.values
-
-             # Fill in the diagonal for the remaining genes
-             full_cov[self.remaining_indices, self.remaining_indices] = self.remaining_var.values
-
-             return full_cov
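The storage scheme above keeps a dense covariance block only for the modeled genes and a plain variance vector for the rest; `to_full_matrix` scatters both back into a dense matrix. A small numpy sketch of that block-plus-diagonal layout, with made-up indices and values:

```python
import numpy as np

# Hypothetical layout: genes 0 and 2 are modeled jointly, while genes 1 and 3
# only keep independent variances.
modeled_idx = np.array([0, 2])
remaining_idx = np.array([1, 3])
block = np.array([[1.0, 0.3],
                  [0.3, 2.0]])          # dense block for the modeled genes
remaining_var = np.array([0.5, 0.7])    # diagonal variances for the rest

total = len(modeled_idx) + len(remaining_idx)
full = np.zeros((total, total))
full[np.ix_(modeled_idx, modeled_idx)] = block      # scatter the block
full[remaining_idx, remaining_idx] = remaining_var  # fill the diagonal
print(full)
# [[1.   0.   0.3  0.  ]
#  [0.   0.5  0.   0.  ]
#  [0.3  0.   2.   0.  ]
#  [0.   0.   0.   0.7 ]]
```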
@@ -1,23 +0,0 @@
- from typing import Union
- import warnings
-
- def standardize_formula(formula: Union[str, dict], allowed_keys=None):
-     # The first element of allowed_keys should be the name of the default parameter
-     if allowed_keys is None:
-         raise ValueError("Internal error: allowed_keys must be specified")
-     formula = {allowed_keys[0]: formula} if isinstance(formula, str) else formula
-
-     formula_keys = set(formula.keys())
-     allowed_keys = set(allowed_keys)
-
-     if not formula_keys & allowed_keys:
-         raise ValueError(f"formula must have at least one of the following keys: {allowed_keys}")
-
-     if extra_keys := formula_keys - allowed_keys:
-         warnings.warn(
-             f"Invalid formulas in dictionary will not be used: {extra_keys}",
-             UserWarning,
-         )
-
-     formula.update({k: '~ 1' for k in allowed_keys - formula_keys})
-     return formula
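Given the definition above, `standardize_formula` promotes a bare string to the first (default) allowed key and back-fills any unspecified keys with an intercept-only formula. Expected behavior, assuming `allowed_keys=['mean', 'sdev']` as the Gaussian marginal below uses it:

```python
standardize_formula("~ condition", allowed_keys=["mean", "sdev"])
# {'mean': '~ condition', 'sdev': '~ 1'}

standardize_formula({"sdev": "~ batch"}, allowed_keys=["mean", "sdev"])
# {'sdev': '~ batch', 'mean': '~ 1'}

standardize_formula({"scale": "~ batch"}, allowed_keys=["mean", "sdev"])
# ValueError: formula must have at least one of the following keys: {'mean', 'sdev'}
```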
@@ -1,65 +0,0 @@
- from .formula import standardize_formula
- from .marginal import GLMPredictor, Marginal
- from .loader import _to_numpy
- from typing import Union, Dict, Optional
- import torch
- import numpy as np
- from scipy.stats import norm
-
- class Gaussian(Marginal):
-     """Gaussian marginal estimator"""
-     def __init__(self, formula: Union[Dict, str]):
-         formula = standardize_formula(formula, allowed_keys=['mean', 'sdev'])
-         super().__init__(formula)
-
-     def setup_optimizer(
-         self,
-         optimizer_class: Optional[callable] = torch.optim.Adam,
-         **optimizer_kwargs,
-     ):
-         if self.loader is None:
-             raise RuntimeError("self.loader is not set (call setup_data first)")
-
-         nll = lambda batch: -self.likelihood(batch).sum()
-         link_fns = {"mean": lambda x: x}
-         self.predict = GLMPredictor(
-             n_outcomes=self.n_outcomes,
-             feature_dims=self.feature_dims,
-             link_fns=link_fns,
-             loss_fn=nll,
-             optimizer_class=optimizer_class,
-             optimizer_kwargs=optimizer_kwargs
-         )
-
-     def likelihood(self, batch):
-         """Compute the Gaussian log-likelihood for a batch."""
-         y, x = batch
-         params = self.predict(x)
-         mu = params.get("mean")
-         sigma = params.get("sdev")
-
-         # log-likelihood for the Gaussian
-         log_likelihood = -0.5 * (torch.log(2 * torch.pi * sigma ** 2) + ((y - mu) ** 2) / (sigma ** 2))
-         return log_likelihood
-
-     def invert(self, u: torch.Tensor, x: Dict[str, torch.Tensor]):
-         """Invert pseudo-observations."""
-         mu, sdev, u = self._local_params(x, u)
-         y = norm(loc=mu, scale=sdev).ppf(u)
-         return torch.from_numpy(y).float()
-
-     def uniformize(self, y: torch.Tensor, x: Dict[str, torch.Tensor], epsilon=1e-6):
-         """Return uniformized pseudo-observations for observations y given covariates x."""
-         # CDF values using scipy's parameterization
-         mu, sdev, y = self._local_params(x, y)
-         u = norm.cdf(y, loc=mu, scale=sdev)
-         u = np.clip(u, epsilon, 1 - epsilon)
-         return torch.from_numpy(u).float()
-
-     def _local_params(self, x, y=None):
-         params = self.predict(x)
-         mu = params.get('mean')
-         sdev = params.get('sdev')
-         if y is None:
-             return _to_numpy(mu, sdev)
-         return _to_numpy(mu, sdev, y)
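The copula machinery only needs `uniformize` and `invert` to be inverses of each other on the interior of the unit interval. A standalone check of the CDF/PPF round trip underlying the Gaussian implementation above, with arbitrary parameter values:

```python
import numpy as np
from scipy.stats import norm

mu, sdev, epsilon = 1.0, 2.0, 1e-6
y = np.array([-1.0, 1.0, 4.0])

# uniformize: map observations to clipped pseudo-observations in (0, 1)
u = np.clip(norm.cdf(y, loc=mu, scale=sdev), epsilon, 1 - epsilon)

# invert: map pseudo-observations back through the quantile function
y_back = norm.ppf(u, loc=mu, scale=sdev)
assert np.allclose(y, y_back)
```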
@@ -1,211 +0,0 @@
- from .kwargs import DEFAULT_ALLOWED_KWARGS, _filter_kwargs
- from anndata import AnnData
- from formulaic import model_matrix
- from torch.utils.data import Dataset, DataLoader
- from typing import Dict
- import numpy as np
- import pandas as pd
- import scipy.sparse
- import torch
-
- def get_device():
-     """Detect and return the best available device (MPS, CUDA, or CPU)."""
-     if torch.backends.mps.is_available():
-         return torch.device("mps")
-     elif torch.cuda.is_available():
-         return torch.device("cuda")
-     else:
-         return torch.device("cpu")
-
-
- class PreloadedDataset(Dataset):
-     """Dataset that assumes x and y are both fully in memory."""
-     def __init__(self, y_tensor, x_tensors, predictor_names):
-         self.y = y_tensor
-         self.x = x_tensors
-         self.predictor_names = predictor_names
-
-     def __len__(self):
-         return len(self.y)
-
-     def __getitem__(self, idx):
-         return self.y[idx], {k: v[idx] for k, v in self.x.items()}
-
- class AnnDataDataset(Dataset):
-     """Simple PyTorch Dataset for AnnData objects.
-
-     Supports chunked loading for backed AnnData objects: the dataset loads
-     contiguous slices of rows (of size `chunk_size`) into memory once and
-     serves individual rows from that cached chunk. Chunks are moved to the
-     compute device for faster access.
-     """
-     def __init__(self, adata: AnnData, formula: Dict[str, str], chunk_size: int):
-         self.adata = adata
-         self.formula = formula
-         self.chunk_size = chunk_size
-         self.device = get_device()
-
-         # keep track of covariate-related information
-         self.obs_levels = categories(self.adata.obs)
-         self.obs_matrices = {}
-         self.predictor_names = None
-
-         # internal cache for the currently loaded chunk
-         self._chunk: AnnData | None = None
-         self._chunk_X = None
-         self._chunk_start = 0
-
-     def __len__(self):
-         return len(self.adata)
-
-     def __getitem__(self, idx):
-         """Return (X, obs) for the given index.
-
-         Loads the chunk containing `idx` into memory (if it is not
-         already cached) and indexes into that chunk.
-         """
-         self._ensure_chunk_loaded(idx)
-         local_idx = idx - self._chunk_start
-
-         # get obs data from the device-cached matrices
-         obs_dict = {}
-         for key in self.formula.keys():
-             obs_dict[key] = self.obs_matrices[key][local_idx: local_idx + 1]
-         return self._chunk_X[local_idx], obs_dict
-
-     def _ensure_chunk_loaded(self, idx: int) -> None:
-         """Load the chunk that contains `idx` into the internal cache."""
-         start = (idx // self.chunk_size) * self.chunk_size
-         end = min(start + self.chunk_size, len(self.adata))
-
-         if (self._chunk is None) or not (self._chunk_start <= idx < self._chunk_start + len(self._chunk)):
-             # load the chunk into memory
-             chunk = self.adata[start:end]
-             if getattr(chunk, 'isbacked', False):
-                 chunk = chunk.to_memory()
-             self._chunk = chunk
-             self._chunk_start = start
-
-             # move the chunk to the compute device
-             X = chunk.X
-             if hasattr(X, 'toarray'):
-                 X = X.toarray()
-             self._chunk_X = torch.tensor(X, dtype=torch.float32).to(self.device)
-
-             # compute model matrices for this chunk's `obs` and move them to the device
-             obs_coded_chunk = code_levels(self._chunk.obs.copy(), self.obs_levels)
-             self.obs_matrices = {}
-             predictor_names = {}
-             for key, f in self.formula.items():
-                 mat = model_matrix(f, obs_coded_chunk)
-                 predictor_names[key] = list(mat.columns)
-                 self.obs_matrices[key] = torch.tensor(mat.values, dtype=torch.float32).to(self.device)
-
-             # capture predictor (column) names from the model matrices once
-             if self.predictor_names is None:
-                 self.predictor_names = predictor_names
-
-
- def adata_loader(
-     adata: AnnData,
-     formula: Dict[str, str],
-     chunk_size: int | None = None,
-     batch_size: int = 1024,
-     shuffle: bool = False,
-     num_workers: int = 0,
-     **kwargs
- ) -> DataLoader:
-     """Create a DataLoader from AnnData that returns batches of (X, obs)."""
-     data_kwargs = _filter_kwargs(kwargs, DEFAULT_ALLOWED_KWARGS['data'])
-     device = get_device()
-
-     # separate the chunked from the non-chunked case
-     if not getattr(adata, 'isbacked', False):
-         dataset = _preloaded_adata(adata, formula, device)
-     else:
-         dataset = AnnDataDataset(adata, formula, chunk_size or 5000)
-
-     return DataLoader(
-         dataset,
-         batch_size=batch_size,
-         shuffle=shuffle,
-         num_workers=num_workers,
-         collate_fn=dict_collate_fn,
-         **data_kwargs
-     )
-
- def obs_loader(obs: pd.DataFrame, marginal_formula, **kwargs):
-     adata = AnnData(X=np.zeros((len(obs), 1)), obs=obs)
-     return adata_loader(
-         adata,
-         marginal_formula,
-         **kwargs
-     )
-
- ################################################################################
- ## Extraction of in-memory AnnData to PreloadedDataset
- ################################################################################
-
- def _preloaded_adata(adata: AnnData, formula: Dict[str, str], device: torch.device) -> PreloadedDataset:
-     X = adata.X
-     if scipy.sparse.issparse(X):
-         X = X.toarray()
-     y = torch.tensor(X, dtype=torch.float32).to(device)
-
-     obs = code_levels(adata.obs.copy(), categories(adata.obs))
-     x = {
-         k: torch.tensor(model_matrix(f, obs).values, dtype=torch.float32).to(device)
-         for k, f in formula.items()
-     }
-     predictor_names = {k: list(model_matrix(f, obs).columns) for k, f in formula.items()}
-     return PreloadedDataset(y, x, predictor_names)
-
- ################################################################################
- ## Helper functions
- ################################################################################
-
- def dict_collate_fn(batch):
-     """Custom collate function for handling dictionary obs tensors."""
-     X_batch = torch.stack([item[0] for item in batch])
-     obs_batch = [item[1] for item in batch]
-
-     obs_dict = {}
-     for key in obs_batch[0].keys():
-         obs_dict[key] = torch.stack([obs[key] for obs in obs_batch])
-     return X_batch, obs_dict
-
- def to_tensor(X):
-     # Keep (n, 1) column matrices 2-D, flatten a single-row (1, 1) input to a
-     # 1-D sample, and squeeze singleton dimensions from anything else.
-     t = torch.tensor(X, dtype=torch.float32)
-     if t.dim() == 2 and t.size(1) == 1:
-         if t.size(0) == 1:
-             return t.view(1)
-         return t
-     return t.squeeze()
-
- def categories(obs):
-     levels = {}
-     for k in obs.columns:
-         obs_type = str(obs[k].dtype)
-         if obs_type in ["category", "object"]:
-             levels[k] = obs[k].unique()
-     return levels
-
-
- def code_levels(obs, categories):
-     for k in obs.columns:
-         if str(obs[k].dtype) == "category":
-             obs[k] = obs[k].astype(pd.CategoricalDtype(categories[k]))
-     return obs
-
- ###############################################################################
- ## Misc. Helper functions
- ###############################################################################
-
- def _to_numpy(*tensors):
-     """Convenience helper: detach, move to CPU, and convert tensors to numpy arrays."""
-     return tuple(t.detach().cpu().numpy() for t in tensors)
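The `categories`/`code_levels` pair exists so that every chunk's model matrix has the same columns: categorical levels are collected once from the full `obs` and pinned on each chunk, so a chunk that happens to contain only one level of a factor still expands to the full set of dummy columns. A small pandas sketch of the failure mode being avoided (using `pd.get_dummies` as a stand-in for formulaic's `model_matrix`):

```python
import pandas as pd

chunk = pd.DataFrame({"condition": ["ctrl", "ctrl"]})  # 'treat' absent here

# Without pinned levels, the encoding only sees one category.
print(pd.get_dummies(chunk["condition"]).columns.tolist())
# ['ctrl']

# With the levels pinned (what code_levels does via CategoricalDtype),
# the absent level still gets a column, keeping chunk matrices aligned.
pinned = pd.CategoricalDtype(["ctrl", "treat"])
chunk["condition"] = chunk["condition"].astype(pinned)
print(pd.get_dummies(chunk["condition"]).columns.tolist())
# ['ctrl', 'treat']
```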