scdesigner 0.0.4__tar.gz → 0.0.5__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release. This version of scdesigner might be problematic.
- {scdesigner-0.0.4 → scdesigner-0.0.5}/PKG-INFO +1 -2
- {scdesigner-0.0.4 → scdesigner-0.0.5}/pyproject.toml +1 -2
- {scdesigner-0.0.4 → scdesigner-0.0.5}/src/scdesigner/minimal/loader.py +85 -40
- {scdesigner-0.0.4 → scdesigner-0.0.5}/src/scdesigner/minimal/marginal.py +33 -24
- {scdesigner-0.0.4 → scdesigner-0.0.5}/src/scdesigner/minimal/standard_copula.py +49 -49
- {scdesigner-0.0.4 → scdesigner-0.0.5}/.gitignore +0 -0
- {scdesigner-0.0.4 → scdesigner-0.0.5}/README.md +0 -0
- {scdesigner-0.0.4 → scdesigner-0.0.5}/src/scdesigner/__init__.py +0 -0
- {scdesigner-0.0.4 → scdesigner-0.0.5}/src/scdesigner/data/__init__.py +0 -0
- {scdesigner-0.0.4 → scdesigner-0.0.5}/src/scdesigner/data/formula.py +0 -0
- {scdesigner-0.0.4 → scdesigner-0.0.5}/src/scdesigner/data/group.py +0 -0
- {scdesigner-0.0.4 → scdesigner-0.0.5}/src/scdesigner/data/sparse.py +0 -0
- {scdesigner-0.0.4 → scdesigner-0.0.5}/src/scdesigner/diagnose/__init__.py +0 -0
- {scdesigner-0.0.4 → scdesigner-0.0.5}/src/scdesigner/diagnose/aic_bic.py +0 -0
- {scdesigner-0.0.4 → scdesigner-0.0.5}/src/scdesigner/diagnose/plot.py +0 -0
- {scdesigner-0.0.4 → scdesigner-0.0.5}/src/scdesigner/estimators/__init__.py +0 -0
- {scdesigner-0.0.4 → scdesigner-0.0.5}/src/scdesigner/estimators/bernoulli.py +0 -0
- {scdesigner-0.0.4 → scdesigner-0.0.5}/src/scdesigner/estimators/gaussian.py +0 -0
- {scdesigner-0.0.4 → scdesigner-0.0.5}/src/scdesigner/estimators/gaussian_copula_factory.py +0 -0
- {scdesigner-0.0.4 → scdesigner-0.0.5}/src/scdesigner/estimators/glm_factory.py +0 -0
- {scdesigner-0.0.4 → scdesigner-0.0.5}/src/scdesigner/estimators/negbin.py +0 -0
- {scdesigner-0.0.4 → scdesigner-0.0.5}/src/scdesigner/estimators/pnmf.py +0 -0
- {scdesigner-0.0.4 → scdesigner-0.0.5}/src/scdesigner/estimators/poisson.py +0 -0
- {scdesigner-0.0.4 → scdesigner-0.0.5}/src/scdesigner/estimators/zero_inflated_negbin.py +0 -0
- {scdesigner-0.0.4 → scdesigner-0.0.5}/src/scdesigner/estimators/zero_inflated_poisson.py +0 -0
- {scdesigner-0.0.4 → scdesigner-0.0.5}/src/scdesigner/format/__init__.py +0 -0
- {scdesigner-0.0.4 → scdesigner-0.0.5}/src/scdesigner/format/format.py +0 -0
- {scdesigner-0.0.4 → scdesigner-0.0.5}/src/scdesigner/format/print.py +0 -0
- {scdesigner-0.0.4 → scdesigner-0.0.5}/src/scdesigner/minimal/__init__.py +0 -0
- {scdesigner-0.0.4 → scdesigner-0.0.5}/src/scdesigner/minimal/bernoulli.py +0 -0
- {scdesigner-0.0.4 → scdesigner-0.0.5}/src/scdesigner/minimal/composite.py +0 -0
- {scdesigner-0.0.4 → scdesigner-0.0.5}/src/scdesigner/minimal/copula.py +0 -0
- {scdesigner-0.0.4 → scdesigner-0.0.5}/src/scdesigner/minimal/formula.py +0 -0
- {scdesigner-0.0.4 → scdesigner-0.0.5}/src/scdesigner/minimal/gaussian.py +0 -0
- {scdesigner-0.0.4 → scdesigner-0.0.5}/src/scdesigner/minimal/kwargs.py +0 -0
- {scdesigner-0.0.4 → scdesigner-0.0.5}/src/scdesigner/minimal/negbin.py +0 -0
- {scdesigner-0.0.4 → scdesigner-0.0.5}/src/scdesigner/minimal/positive_nonnegative_matrix_factorization.py +0 -0
- {scdesigner-0.0.4 → scdesigner-0.0.5}/src/scdesigner/minimal/scd3.py +0 -0
- {scdesigner-0.0.4 → scdesigner-0.0.5}/src/scdesigner/minimal/scd3_instances.py +0 -0
- {scdesigner-0.0.4 → scdesigner-0.0.5}/src/scdesigner/minimal/simulator.py +0 -0
- {scdesigner-0.0.4 → scdesigner-0.0.5}/src/scdesigner/minimal/transform.py +0 -0
- {scdesigner-0.0.4 → scdesigner-0.0.5}/src/scdesigner/minimal/zero_inflated_negbin.py +0 -0
- {scdesigner-0.0.4 → scdesigner-0.0.5}/src/scdesigner/predictors/__init__.py +0 -0
- {scdesigner-0.0.4 → scdesigner-0.0.5}/src/scdesigner/predictors/bernoulli.py +0 -0
- {scdesigner-0.0.4 → scdesigner-0.0.5}/src/scdesigner/predictors/gaussian.py +0 -0
- {scdesigner-0.0.4 → scdesigner-0.0.5}/src/scdesigner/predictors/negbin.py +0 -0
- {scdesigner-0.0.4 → scdesigner-0.0.5}/src/scdesigner/predictors/poisson.py +0 -0
- {scdesigner-0.0.4 → scdesigner-0.0.5}/src/scdesigner/predictors/zero_inflated_negbin.py +0 -0
- {scdesigner-0.0.4 → scdesigner-0.0.5}/src/scdesigner/predictors/zero_inflated_poisson.py +0 -0
- {scdesigner-0.0.4 → scdesigner-0.0.5}/src/scdesigner/samplers/__init__.py +0 -0
- {scdesigner-0.0.4 → scdesigner-0.0.5}/src/scdesigner/samplers/bernoulli.py +0 -0
- {scdesigner-0.0.4 → scdesigner-0.0.5}/src/scdesigner/samplers/gaussian.py +0 -0
- {scdesigner-0.0.4 → scdesigner-0.0.5}/src/scdesigner/samplers/glm_factory.py +0 -0
- {scdesigner-0.0.4 → scdesigner-0.0.5}/src/scdesigner/samplers/negbin.py +0 -0
- {scdesigner-0.0.4 → scdesigner-0.0.5}/src/scdesigner/samplers/poisson.py +0 -0
- {scdesigner-0.0.4 → scdesigner-0.0.5}/src/scdesigner/samplers/zero_inflated_negbin.py +0 -0
- {scdesigner-0.0.4 → scdesigner-0.0.5}/src/scdesigner/samplers/zero_inflated_poisson.py +0 -0
- {scdesigner-0.0.4 → scdesigner-0.0.5}/src/scdesigner/simulators/__init__.py +0 -0
- {scdesigner-0.0.4 → scdesigner-0.0.5}/src/scdesigner/simulators/composite_regressor.py +0 -0
- {scdesigner-0.0.4 → scdesigner-0.0.5}/src/scdesigner/simulators/glm_simulator.py +0 -0
- {scdesigner-0.0.4 → scdesigner-0.0.5}/src/scdesigner/simulators/pnmf_regression.py +0 -0
- {scdesigner-0.0.4 → scdesigner-0.0.5}/src/scdesigner/transform/__init__.py +0 -0
- {scdesigner-0.0.4 → scdesigner-0.0.5}/src/scdesigner/transform/amplify.py +0 -0
- {scdesigner-0.0.4 → scdesigner-0.0.5}/src/scdesigner/transform/mask.py +0 -0
- {scdesigner-0.0.4 → scdesigner-0.0.5}/src/scdesigner/transform/nullify.py +0 -0
- {scdesigner-0.0.4 → scdesigner-0.0.5}/src/scdesigner/transform/split.py +0 -0
- {scdesigner-0.0.4 → scdesigner-0.0.5}/src/scdesigner/transform/substitute.py +0 -0
- {scdesigner-0.0.4 → scdesigner-0.0.5}/tests/__init__.py +0 -0
- {scdesigner-0.0.4 → scdesigner-0.0.5}/tests/test_negative_binomial.py +0 -0
- {scdesigner-0.0.4 → scdesigner-0.0.5}/tests/test_simulator.py +0 -0
{scdesigner-0.0.4 → scdesigner-0.0.5}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: scdesigner
-Version: 0.0.4
+Version: 0.0.5
 Summary: Interactive simulation for rigorous and transparent multi-omics analysis.
 Project-URL: Homepage, https://github.com/krisrs1128/scDesigner/
 Project-URL: Issues, https://github.com/krisrs1128/scDesigner/Issues/
@@ -11,7 +11,6 @@ Classifier: Programming Language :: Python :: 3
 Requires-Python: >=3.8
 Requires-Dist: anndata
 Requires-Dist: formulaic
-Requires-Dist: lightning
 Requires-Dist: numpy
 Requires-Dist: pandas
 Requires-Dist: rich
{scdesigner-0.0.4 → scdesigner-0.0.5}/pyproject.toml

@@ -1,6 +1,6 @@
 [project]
 name = "scdesigner"
-version = "0.0.4"
+version = "0.0.5"
 authors = [
   { name="Kris Sankaran", email="ksankaran@wisc.edu" },
 ]
@@ -15,7 +15,6 @@ classifiers = [
 dependencies = [
     "anndata",
     "formulaic",
-    "lightning",
     "numpy",
     "pandas",
     "rich",
{scdesigner-0.0.4 → scdesigner-0.0.5}/src/scdesigner/minimal/loader.py

@@ -5,21 +5,45 @@ from torch.utils.data import Dataset, DataLoader
 from typing import Dict
 import numpy as np
 import pandas as pd
+import scipy.sparse
 import torch
 
+def get_device():
+    """Detect and return the best available device (MPS, CUDA, or CPU)."""
+    if torch.backends.mps.is_available():
+        return torch.device("mps")
+    elif torch.cuda.is_available():
+        return torch.device("cuda")
+    else:
+        return torch.device("cpu")
+
+
+class PreloadedDataset(Dataset):
+    """Dataset that assumes x and y are both fully in memory."""
+    def __init__(self, y_tensor, x_tensors, predictor_names):
+        self.y = y_tensor
+        self.x = x_tensors
+        self.predictor_names = predictor_names
+
+    def __len__(self):
+        return len(self.y)
+
+    def __getitem__(self, idx):
+        return self.y[idx], {k: v[idx] for k, v in self.x.items()}
+
 class AnnDataDataset(Dataset):
     """Simple PyTorch Dataset for AnnData objects.
 
     Supports optional chunked loading for backed AnnData objects. When
     `chunk_size` is provided, the dataset will load contiguous slices
     of rows (of size `chunk_size`) into memory once and serve individual
-    rows from that cached chunk.
-    a per-row basis which is expensive for large backed files.
+    rows from that cached chunk. Chunks are moved to device for faster access.
     """
     def __init__(self, adata: AnnData, formula: Dict[str, str], chunk_size: int):
         self.adata = adata
         self.formula = formula
         self.chunk_size = chunk_size
+        self.device = get_device()
 
         # keeping track of covariate-related information
         self.obs_levels = categories(self.adata.obs)
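The `PreloadedDataset` added above returns a `(y, {name: x})` pair per index, and PyTorch's default collate stacks `y` and batches each entry of the dict. A minimal sketch of that contract with toy tensors (the class body is copied from the diff; the shapes and column names are made up):

import torch
from torch.utils.data import DataLoader, Dataset

class PreloadedDataset(Dataset):
    """Dataset that assumes x and y are both fully in memory (as in the diff)."""
    def __init__(self, y_tensor, x_tensors, predictor_names):
        self.y = y_tensor
        self.x = x_tensors
        self.predictor_names = predictor_names

    def __len__(self):
        return len(self.y)

    def __getitem__(self, idx):
        return self.y[idx], {k: v[idx] for k, v in self.x.items()}

# Toy data: 100 cells, 10 genes, one 3-column model matrix named "mu".
y = torch.randn(100, 10)
x = {"mu": torch.randn(100, 3)}
ds = PreloadedDataset(y, x, predictor_names={"mu": ["intercept", "b1", "b2"]})

# Default collate batches the tensor and the dict entries together.
batch_y, batch_x = next(iter(DataLoader(ds, batch_size=16)))
print(batch_y.shape, batch_x["mu"].shape)  # torch.Size([16, 10]) torch.Size([16, 3])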
@@ -28,6 +52,7 @@ class AnnDataDataset(Dataset):
 
         # Internal cache for the currently loaded chunk
         self._chunk: AnnData | None = None
+        self._chunk_X = None
         self._chunk_start = 0
 
     def __len__(self):
@@ -42,19 +67,12 @@ class AnnDataDataset(Dataset):
         """
         self._ensure_chunk_loaded(idx)
         local_idx = idx - self._chunk_start
-        adata_slice = self._chunk[local_idx]
-
-        # Get X data, accounting for potential sparse matrices
-        X = adata_slice.X
-        if hasattr(X, 'toarray'):
-            X = X.toarray()
 
-        # Get obs data
+        # Get obs data from GPU-cached matrices
         obs_dict = {}
         for key in self.formula.keys():
-
-
-        return to_tensor(X), obs_dict
+            obs_dict[key] = self.obs_matrices[key][local_idx: local_idx + 1]
+        return self._chunk_X[local_idx], obs_dict
 
     def _ensure_chunk_loaded(self, idx: int) -> None:
         """Load the chunk that contains `idx` into the internal cache."""
@@ -69,36 +87,45 @@ class AnnDataDataset(Dataset):
         self._chunk = chunk
         self._chunk_start = start
 
-        #
-
+        # Move chunk to GPU
+        X = chunk.X
+        if hasattr(X, 'toarray'):
+            X = X.toarray()
+        self._chunk_X = torch.tensor(X, dtype=torch.float32).to(self.device)
+
+        # Compute model matrices for this chunk's `obs` and move to GPU
         obs_coded_chunk = code_levels(self._chunk.obs.copy(), self.obs_levels)
         self.obs_matrices = {}
+        predictor_names = {}
         for key, f in self.formula.items():
-
+            mat = model_matrix(f, obs_coded_chunk)
+            predictor_names[key] = list(mat.columns)
+            self.obs_matrices[key] = torch.tensor(mat.values, dtype=torch.float32).to(self.device)
 
         # Capture predictor (column) names from the model matrices once.
         if self.predictor_names is None:
-            self.predictor_names =
-
-
-def adata_loader(
-
-
-
-
-
-
-
-    """
+            self.predictor_names = predictor_names
+
+
+def adata_loader(
+    adata: AnnData,
+    formula: Dict[str, str],
+    chunk_size: int = None,
+    batch_size: int = 1024,
+    shuffle: bool = False,
+    num_workers: int = 0,
+    **kwargs
+) -> DataLoader:
+    """Create a DataLoader from AnnData that returns batches of (X, obs)."""
     data_kwargs = _filter_kwargs(kwargs, DEFAULT_ALLOWED_KWARGS['data'])
-
-
-
-
+    device = get_device()
+
+    # separate chunked from non-chunked cases
+    if not getattr(adata, 'isbacked', False):
+        dataset = _preloaded_adata(adata, formula, device)
+    else:
+        dataset = AnnDataDataset(adata, formula, chunk_size or 5000)
 
-    dataset = AnnDataDataset(adata, formula, chunk_size)
     return DataLoader(
         dataset,
         batch_size=batch_size,
@@ -109,12 +136,30 @@ def adata_loader(adata: AnnData,
     )
 
 def obs_loader(obs: pd.DataFrame, marginal_formula, **kwargs):
-
-
-
-
-
-
+    adata = AnnData(X=np.zeros((len(obs), 1)), obs=obs)
+    return adata_loader(
+        adata,
+        marginal_formula,
+        **kwargs
+    )
+
+################################################################################
+## Extraction of in-memory AnnData to PreloadedDataset
+################################################################################
+
+def _preloaded_adata(adata: AnnData, formula: Dict[str, str], device: torch.device) -> PreloadedDataset:
+    X = adata.X
+    if scipy.sparse.issparse(X):
+        X = X.toarray()
+    y = torch.tensor(X, dtype=torch.float32).to(device)
+
+    obs = code_levels(adata.obs.copy(), categories(adata.obs))
+    x = {
+        k: torch.tensor(model_matrix(f, obs).values, dtype=torch.float32).to(device)
+        for k, f in formula.items()
+    }
+    predictor_names = {k: list(model_matrix(f, obs).columns) for k, f in formula.items()}
+    return PreloadedDataset(y, x, predictor_names)
 
 ################################################################################
 ## Helper functions
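Taken together, the loader changes route in-memory AnnData through `_preloaded_adata`/`PreloadedDataset` and backed AnnData through the chunked `AnnDataDataset`. A hypothetical usage sketch, assuming scdesigner 0.0.5 is installed and importable as shown in the diff; the covariate name and formula string here are illustrative, not scdesigner's documented API:

import anndata
import numpy as np
import pandas as pd
from scdesigner.minimal.loader import adata_loader

# Toy AnnData: 100 cells, 10 genes, one categorical covariate.
obs = pd.DataFrame({"group": pd.Categorical(["a", "b"] * 50)})
adata = anndata.AnnData(X=np.random.poisson(2.0, (100, 10)).astype(np.float32), obs=obs)

# In-memory AnnData (isbacked == False) takes the PreloadedDataset path:
# X and all model matrices are built once and moved to get_device()'s device.
# A backed AnnData would instead use AnnDataDataset with chunk_size or 5000.
loader = adata_loader(adata, formula={"mu": "~ group"}, batch_size=32)
y, x = next(iter(loader))
print(y.shape, {k: v.shape for k, v in x.items()})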
{scdesigner-0.0.4 → scdesigner-0.0.5}/src/scdesigner/minimal/marginal.py

@@ -1,9 +1,8 @@
 from .kwargs import DEFAULT_ALLOWED_KWARGS, _filter_kwargs
-from .loader import adata_loader
+from .loader import adata_loader, get_device
 from anndata import AnnData
 from typing import Union, Dict, Optional, Tuple
 import pandas as pd
-import pytorch_lightning as pl
 import torch
 import torch.nn as nn
 from abc import ABC, abstractmethod
@@ -18,6 +17,7 @@ class Marginal(ABC):
         self.predict = None
         self.predictor_names = None
         self.parameters = None
+        self.device = get_device()
 
     def setup_data(self, adata: AnnData, batch_size: int = 1024, **kwargs):
         """Set up the dataloader for the AnnData object."""
@@ -29,13 +29,30 @@ class Marginal(ABC):
         self.feature_dims = {k: v.shape[1] for k, v in obs_batch.items()}
         self.predictor_names = self.loader.dataset.predictor_names
 
-    def fit(self, **kwargs):
-        """Fit the marginal predictor"""
+    def fit(self, max_epochs: int = 100, **kwargs):
+        """Fit the marginal predictor using vanilla PyTorch training loop."""
         if self.predict is None:
             self.setup_optimizer(**kwargs)
-
-
-
+
+        for epoch in range(max_epochs):
+            epoch_loss, n_batches = 0.0, 0
+
+            for batch in self.loader:
+                y, x = batch
+                if y.device != self.device:
+                    y = y.to(self.device)
+                    x = {k: v.to(self.device) for k, v in x.items()}
+
+                self.predict.optimizer.zero_grad()
+                loss = self.predict.loss_fn((y, x))
+                loss.backward()
+                self.predict.optimizer.step()
+
+                epoch_loss += loss.item()
+                n_batches += 1
+
+            avg_loss = epoch_loss / n_batches
+            print(f"Epoch {epoch}/{max_epochs}, Loss: {avg_loss:.4f}", end='\r')
         self.parameters = self.format_parameters()
 
     def format_parameters(self):
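With the `lightning` dependency dropped, `Marginal.fit` now runs the standard PyTorch loop shown above: zero the gradients, compute the loss, backpropagate, step the optimizer, and track a running epoch loss. A self-contained sketch of the same pattern, with a toy linear model and MSE loss standing in for scdesigner's `predict` and `loss_fn`:

import torch
import torch.nn as nn

model = nn.Linear(3, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
batches = [(torch.randn(8, 3), torch.randn(8, 1)) for _ in range(10)]  # toy batches

for epoch in range(5):
    epoch_loss, n_batches = 0.0, 0
    for x, y in batches:
        optimizer.zero_grad()                       # reset accumulated gradients
        loss = nn.functional.mse_loss(model(x), y)  # stand-in for predict.loss_fn
        loss.backward()                             # backpropagate
        optimizer.step()                            # update parameters
        epoch_loss += loss.item()
        n_batches += 1
    print(f"Epoch {epoch}, Loss: {epoch_loss / n_batches:.4f}", end="\r")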
@@ -85,7 +102,7 @@ class Marginal(ABC):
         raise NotImplementedError
 
 
-class GLMPredictor(pl.LightningModule):
+class GLMPredictor(nn.Module):
     """GLM-style predictor with arbitrary named parameters.
 
     Args:
@@ -111,21 +128,22 @@ class GLMPredictor(pl.LightningModule):
         self.feature_dims = dict(feature_dims)
         self.param_names = list(self.feature_dims.keys())
 
-        # create default link functions and parameter matrices
         self.link_fns = link_fns or {k: torch.exp for k in self.param_names}
         self.coefs = nn.ParameterDict()
         for key, dim in self.feature_dims.items():
             self.coefs[key] = nn.Parameter(torch.zeros(dim, self.n_outcomes))
-
-        # optimization parameters
         self.reset_parameters()
+
         self.loss_fn = loss_fn
-        self.
-
+        self.to(get_device())
+
+        optimizer_kwargs = optimizer_kwargs or {}
+        filtered_kwargs = _filter_kwargs(optimizer_kwargs, DEFAULT_ALLOWED_KWARGS['optimizer'])
+        self.optimizer = optimizer_class(self.parameters(), **filtered_kwargs)
 
     def reset_parameters(self):
         for p in self.coefs.values():
-            nn.init.normal_(p, mean=0.0, std=1e-
+            nn.init.normal_(p, mean=0.0, std=1e-4)
 
     def forward(self, obs_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
         out = {}
@@ -133,13 +151,4 @@ class GLMPredictor(pl.LightningModule):
             x_beta = obs_dict[name] @ self.coefs[name]
             link = self.link_fns.get(name, torch.exp)
             out[name] = link(x_beta)
-        return out
-
-    def training_step(self, batch):
-        loss = self.loss_fn(batch)
-        self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True)
-        return loss
-
-    def configure_optimizers(self, **kwargs):
-        optimizer_kwargs = _filter_kwargs(self.optimizer_kwargs, DEFAULT_ALLOWED_KWARGS['optimizer'])
-        return self.optimizer_class(self.parameters(), **optimizer_kwargs)
+        return out
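`GLMPredictor` is now a plain `nn.Module` that owns its optimizer directly instead of relying on Lightning's `training_step`/`configure_optimizers` hooks. It holds one coefficient matrix per named GLM parameter in an `nn.ParameterDict` and applies a link function (default `torch.exp`) to each linear predictor. A standalone sketch of that forward pattern, with hypothetical dimensions:

import torch
import torch.nn as nn

# Hypothetical setup: two named GLM parameters over 5 outcomes (genes).
feature_dims = {"mu": 3, "theta": 2}   # covariate columns per parameter
n_outcomes = 5

coefs = nn.ParameterDict({
    k: nn.Parameter(torch.zeros(d, n_outcomes)) for k, d in feature_dims.items()
})
link_fns = {k: torch.exp for k in feature_dims}  # the diff's default link

# Forward pass: per-parameter linear predictor followed by its link function.
obs_dict = {"mu": torch.randn(8, 3), "theta": torch.randn(8, 2)}
out = {k: link_fns[k](obs_dict[k] @ coefs[k]) for k in feature_dims}
print({k: tuple(v.shape) for k, v in out.items()})  # {'mu': (8, 5), 'theta': (8, 5)}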
{scdesigner-0.0.4 → scdesigner-0.0.5}/src/scdesigner/minimal/standard_copula.py

@@ -51,16 +51,16 @@ class StandardCopula(Copula):
     def fit(self, uniformizer: Callable, **kwargs):
         """
         Fit the copula covariance model.
-
+
         Args:
             uniformizer (Callable): Function to convert data to uniform distribution
             **kwargs: Additional arguments
                 top_k (int, optional): Use only top-k most expressed genes for covariance estimation.
                     If None, estimates full covariance for all genes.
-
+
         Returns:
             None: Stores fitted parameters in self.parameters as dict of CovarianceStructure objects.
-
+
         Raises:
             ValueError: If top_k is not a positive integer or exceeds n_outcomes
         """
@@ -76,11 +76,11 @@ class StandardCopula(Copula):
             sorted_indices = np.argsort(gene_total_expression)
             top_k_indices = sorted_indices[-top_k:]
             remaining_indices = sorted_indices[:-top_k]
-            covariances = self._compute_block_covariance(uniformizer, top_k_indices, 
+            covariances = self._compute_block_covariance(uniformizer, top_k_indices,
                                                          remaining_indices, top_k)
         else:
             covariances = self._compute_full_covariance(uniformizer)
-
+
         self.parameters = covariances
 
     def pseudo_obs(self, x_dict: Dict):
@@ -88,7 +88,7 @@ class StandardCopula(Copula):
         # {"group1": [indices of group 1], "group2": [indices of group 2]}
         # The initialization method ensures that x_dict will always have a "group" key.
         group_data = x_dict.get("group")
-        memberships = group_data.numpy()
+        memberships = group_data.cpu().numpy()
         group_ix = {g: np.where(memberships[:, self.group_col[g] == 1])[0] for g in self.groups}
 
         # initialize the result
@@ -106,14 +106,14 @@ class StandardCopula(Copula):
     def likelihood(self, uniformizer: Callable, batch: Tuple[torch.Tensor, Dict[str, torch.Tensor]]):
         """
         Compute likelihood of data given the copula model.
-
+
         Args:
             uniformizer (Callable): Function to convert expression data to uniform distribution
             batch (Tuple[torch.Tensor, Dict[str, torch.Tensor]]): Data batch containing:
                 - Y (torch.Tensor): Expression data of shape (n_cells, n_genes)
                 - X_dict (Dict[str, torch.Tensor]): Covariates dict with keys as parameter names
                   and values as tensors of shape (n_cells, n_covariates)
-
+
         Returns:
             np.ndarray: Log-likelihood for each cell, shape (n_cells,)
         """
@@ -132,19 +132,19 @@ class StandardCopula(Copula):
         group_ix = {g: np.where(memberships[:, self.group_col[g] == 1])[0] for g in self.groups}
 
         ll = np.zeros(len(z))
-
+
         for group, cov_struct in parameters.items():
             ix = group_ix[group]
             if len(ix) > 0:
                 z_modeled = z[ix][:, cov_struct.modeled_indices]
-
+
                 ll_modeled = multivariate_normal.logpdf(z_modeled,
-                                                        np.zeros(cov_struct.num_modeled_genes),
+                                                        np.zeros(cov_struct.num_modeled_genes),
                                                         cov_struct.cov.values)
                 if cov_struct.num_remaining_genes > 0:
                     z_remaining = z[ix][:, cov_struct.remaining_indices]
                     ll_remaining = norm.logpdf(z_remaining,
-                                               loc=0,
+                                               loc=0,
                                                scale = np.sqrt(cov_struct.remaining_var.values))
                 else:
                     ll_remaining = 0
@@ -155,7 +155,7 @@ class StandardCopula(Copula):
         S = self.parameters
         per_group = [((S[g].num_modeled_genes * (S[g].num_modeled_genes - 1)) / 2) for g in self.groups]
         return sum(per_group)
-
+
     def _validate_parameters(self, **kwargs):
         top_k = kwargs.get("top_k", None)
         if top_k is not None:
@@ -166,14 +166,14 @@ class StandardCopula(Copula):
         if top_k > self.n_outcomes:
             raise ValueError(f"top_k ({top_k}) cannot exceed number of outcomes ({self.n_outcomes})")
         return top_k
-
-
+
+
 
     def _accumulate_top_k_stats(self, uniformizer:Callable, top_k_idx, rem_idx, top_k) \
-        -> Tuple[Dict[Union[str, int], np.ndarray],
-                 Dict[Union[str, int], np.ndarray],
-                 Dict[Union[str, int], np.ndarray],
-                 Dict[Union[str, int], np.ndarray],
+        -> Tuple[Dict[Union[str, int], np.ndarray],
+                 Dict[Union[str, int], np.ndarray],
+                 Dict[Union[str, int], np.ndarray],
+                 Dict[Union[str, int], np.ndarray],
                  Dict[Union[str, int], int]]:
         """Accumulate sufficient statistics for top-k covariance estimation.
@@ -198,7 +198,7 @@ class StandardCopula(Copula):
 
         for y, x_dict in tqdm(self.loader, desc="Estimating top-k copula covariance"):
             group_data = x_dict.get("group")
-            memberships = group_data.numpy()
+            memberships = group_data.cpu().numpy()
             u = uniformizer(y, x_dict)
             z = norm.ppf(u)
 
@@ -211,20 +211,20 @@ class StandardCopula(Copula):
             n_g = mask.sum()
 
             top_k_z, rem_z = z_g[:, top_k_idx], z_g[:, rem_idx]
-
+
             top_k_sums[g] += top_k_z.sum(axis=0)
             top_k_second_moments[g] += top_k_z.T @ top_k_z
-
+
             rem_sums[g] += rem_z.sum(axis=0)
             rem_second_moments[g] += (rem_z ** 2).sum(axis=0)
-
+
             Ng[g] += n_g
 
         return top_k_sums, top_k_second_moments, rem_sums, rem_second_moments, Ng
-
+
     def _accumulate_full_stats(self, uniformizer:Callable) \
-        -> Tuple[Dict[Union[str, int], np.ndarray],
-                 Dict[Union[str, int], np.ndarray],
+        -> Tuple[Dict[Union[str, int], np.ndarray],
+                 Dict[Union[str, int], np.ndarray],
                  Dict[Union[str, int], int]]:
         """Accumulate sufficient statistics for full covariance estimation.
 
@@ -242,14 +242,14 @@ class StandardCopula(Copula):
 
         for y, x_dict in tqdm(self.loader, desc="Estimating copula covariance"):
             group_data = x_dict.get("group")
-            memberships = group_data.numpy()
-
+            memberships = group_data.cpu().numpy()
+
             u = uniformizer(y, x_dict)
             z = norm.ppf(u)
 
             for g in self.groups:
                 mask = memberships[:, self.group_col[g]] == 1
-
+
                 if not np.any(mask):
                     continue
 
@@ -258,12 +258,12 @@ class StandardCopula(Copula):
 
                 second_moments[g] += z_g.T @ z_g
                 sums[g] += z_g.sum(axis=0)
-
+
                 Ng[g] += n_g
 
         return sums, second_moments, Ng
-
-    def _compute_block_covariance(self, uniformizer:Callable,
+
+    def _compute_block_covariance(self, uniformizer:Callable,
                                   top_k_idx: np.ndarray, rem_idx: np.ndarray, top_k: int) \
         -> Dict[Union[str, int], CovarianceStructure]:
         """Compute the covariance matrix for the top-k and remaining genes.
@@ -300,7 +300,7 @@ class StandardCopula(Copula):
             remaining_names=remaining_names
         )
         return covariance
-
+
     def _compute_full_covariance(self, uniformizer:Callable) -> Dict[Union[str, int], CovarianceStructure]:
         """Compute the covariance matrix for the full genes.
 
@@ -327,7 +327,7 @@ class StandardCopula(Copula):
             remaining_names=None
         )
         return covariance
-
+
     def _fast_normal_pseudo_obs(self, n_samples: int, cov_struct: CovarianceStructure) -> np.ndarray:
         """Sample pseudo-observations from the covariance structure.
 
@@ -339,28 +339,28 @@ class StandardCopula(Copula):
             np.ndarray: Pseudo-observations with shape (n_samples, total_genes)
         """
         u = np.zeros((n_samples, cov_struct.total_genes))
-
+
         z_modeled = np.random.multivariate_normal(
-            mean=np.zeros(cov_struct.num_modeled_genes),
-            cov=cov_struct.cov.values,
+            mean=np.zeros(cov_struct.num_modeled_genes),
+            cov=cov_struct.cov.values,
            size=n_samples
         )
-
+
         z_remaining = np.random.normal(
-            loc=0,
-            scale=cov_struct.remaining_var.values ** 0.5,
+            loc=0,
+            scale=cov_struct.remaining_var.values ** 0.5,
            size=(n_samples, cov_struct.num_remaining_genes)
         )
-
+
         normal_distn_modeled = norm(0, np.diag(cov_struct.cov.values) ** 0.5)
         u[:, cov_struct.modeled_indices] = normal_distn_modeled.cdf(z_modeled)
-
+
         normal_distn_remaining = norm(0, cov_struct.remaining_var.values ** 0.5)
         u[:, cov_struct.remaining_indices] = normal_distn_remaining.cdf(z_remaining)
-
+
         return u
-
-    def _normal_pseudo_obs(self, n_samples: int, cov_struct: CovarianceStructure) -> np.ndarray:
+
+    def _normal_pseudo_obs(self, n_samples: int, cov_struct: CovarianceStructure) -> np.ndarray:
         """Sample pseudo-observations from the covariance structure.
 
         Args:
@@ -372,12 +372,12 @@ class StandardCopula(Copula):
         """
         u = np.zeros((n_samples, cov_struct.total_genes))
         z = np.random.multivariate_normal(
-            mean=np.zeros(cov_struct.total_genes),
-            cov=cov_struct.cov.values,
+            mean=np.zeros(cov_struct.total_genes),
+            cov=cov_struct.cov.values,
             size=n_samples
         )
-
+
         normal_distn = norm(0, np.diag(cov_struct.cov.values) ** 0.5)
         u = normal_distn.cdf(z)
-
+
         return u
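The repeated `.numpy()` → `.cpu().numpy()` edits in this file follow from the device changes above: `Tensor.numpy()` only works on CPU tensors and raises a TypeError for CUDA or MPS tensors, so the group-membership tensor must be copied to host memory first. A minimal illustration:

import torch

t = torch.arange(4, dtype=torch.float32)
if torch.backends.mps.is_available():
    t = t.to("mps")
elif torch.cuda.is_available():
    t = t.to("cuda")

# t.numpy() would raise TypeError on a non-CPU tensor; .cpu() copies to host
# first (and is a no-op if the tensor is already on the CPU).
arr = t.cpu().numpy()
print(arr)  # [0. 1. 2. 3.]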