scdesigner 0.0.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of scdesigner might be problematic. Click here for more details.
- scdesigner-0.0.1/.gitignore +8 -0
- scdesigner-0.0.1/PKG-INFO +23 -0
- scdesigner-0.0.1/README.md +2 -0
- scdesigner-0.0.1/pyproject.toml +35 -0
- scdesigner-0.0.1/src/scdesigner/__init__.py +0 -0
- scdesigner-0.0.1/src/scdesigner/data/__init__.py +16 -0
- scdesigner-0.0.1/src/scdesigner/data/formula.py +137 -0
- scdesigner-0.0.1/src/scdesigner/data/group.py +123 -0
- scdesigner-0.0.1/src/scdesigner/data/sparse.py +39 -0
- scdesigner-0.0.1/src/scdesigner/diagnose/__init__.py +65 -0
- scdesigner-0.0.1/src/scdesigner/diagnose/aic_bic.py +119 -0
- scdesigner-0.0.1/src/scdesigner/diagnose/plot.py +242 -0
- scdesigner-0.0.1/src/scdesigner/estimators/__init__.py +27 -0
- scdesigner-0.0.1/src/scdesigner/estimators/bernoulli.py +85 -0
- scdesigner-0.0.1/src/scdesigner/estimators/gaussian.py +121 -0
- scdesigner-0.0.1/src/scdesigner/estimators/gaussian_copula_factory.py +152 -0
- scdesigner-0.0.1/src/scdesigner/estimators/glm_factory.py +75 -0
- scdesigner-0.0.1/src/scdesigner/estimators/negbin.py +129 -0
- scdesigner-0.0.1/src/scdesigner/estimators/pnmf.py +160 -0
- scdesigner-0.0.1/src/scdesigner/estimators/poisson.py +100 -0
- scdesigner-0.0.1/src/scdesigner/estimators/zero_inflated_negbin.py +195 -0
- scdesigner-0.0.1/src/scdesigner/estimators/zero_inflated_poisson.py +85 -0
- scdesigner-0.0.1/src/scdesigner/format/__init__.py +4 -0
- scdesigner-0.0.1/src/scdesigner/format/format.py +20 -0
- scdesigner-0.0.1/src/scdesigner/format/print.py +30 -0
- scdesigner-0.0.1/src/scdesigner/minimal/__init__.py +17 -0
- scdesigner-0.0.1/src/scdesigner/minimal/bernoulli.py +61 -0
- scdesigner-0.0.1/src/scdesigner/minimal/composite.py +119 -0
- scdesigner-0.0.1/src/scdesigner/minimal/copula.py +33 -0
- scdesigner-0.0.1/src/scdesigner/minimal/formula.py +23 -0
- scdesigner-0.0.1/src/scdesigner/minimal/gaussian.py +65 -0
- scdesigner-0.0.1/src/scdesigner/minimal/kwargs.py +24 -0
- scdesigner-0.0.1/src/scdesigner/minimal/loader.py +166 -0
- scdesigner-0.0.1/src/scdesigner/minimal/marginal.py +140 -0
- scdesigner-0.0.1/src/scdesigner/minimal/negbin.py +73 -0
- scdesigner-0.0.1/src/scdesigner/minimal/positive_nonnegative_matrix_factorization.py +231 -0
- scdesigner-0.0.1/src/scdesigner/minimal/scd3.py +95 -0
- scdesigner-0.0.1/src/scdesigner/minimal/scd3_instances.py +50 -0
- scdesigner-0.0.1/src/scdesigner/minimal/simulator.py +25 -0
- scdesigner-0.0.1/src/scdesigner/minimal/standard_covariance.py +124 -0
- scdesigner-0.0.1/src/scdesigner/minimal/transform.py +145 -0
- scdesigner-0.0.1/src/scdesigner/minimal/zero_inflated_negbin.py +86 -0
- scdesigner-0.0.1/src/scdesigner/predictors/__init__.py +15 -0
- scdesigner-0.0.1/src/scdesigner/predictors/bernoulli.py +9 -0
- scdesigner-0.0.1/src/scdesigner/predictors/gaussian.py +16 -0
- scdesigner-0.0.1/src/scdesigner/predictors/negbin.py +17 -0
- scdesigner-0.0.1/src/scdesigner/predictors/poisson.py +12 -0
- scdesigner-0.0.1/src/scdesigner/predictors/zero_inflated_negbin.py +18 -0
- scdesigner-0.0.1/src/scdesigner/predictors/zero_inflated_poisson.py +18 -0
- scdesigner-0.0.1/src/scdesigner/samplers/__init__.py +23 -0
- scdesigner-0.0.1/src/scdesigner/samplers/bernoulli.py +27 -0
- scdesigner-0.0.1/src/scdesigner/samplers/gaussian.py +25 -0
- scdesigner-0.0.1/src/scdesigner/samplers/glm_factory.py +41 -0
- scdesigner-0.0.1/src/scdesigner/samplers/negbin.py +25 -0
- scdesigner-0.0.1/src/scdesigner/samplers/poisson.py +25 -0
- scdesigner-0.0.1/src/scdesigner/samplers/zero_inflated_negbin.py +40 -0
- scdesigner-0.0.1/src/scdesigner/samplers/zero_inflated_poisson.py +16 -0
- scdesigner-0.0.1/src/scdesigner/simulators/__init__.py +31 -0
- scdesigner-0.0.1/src/scdesigner/simulators/composite_regressor.py +72 -0
- scdesigner-0.0.1/src/scdesigner/simulators/glm_simulator.py +167 -0
- scdesigner-0.0.1/src/scdesigner/simulators/pnmf_regression.py +61 -0
- scdesigner-0.0.1/src/scdesigner/transform/__init__.py +7 -0
- scdesigner-0.0.1/src/scdesigner/transform/amplify.py +14 -0
- scdesigner-0.0.1/src/scdesigner/transform/mask.py +33 -0
- scdesigner-0.0.1/src/scdesigner/transform/nullify.py +25 -0
- scdesigner-0.0.1/src/scdesigner/transform/split.py +23 -0
- scdesigner-0.0.1/src/scdesigner/transform/substitute.py +14 -0
- scdesigner-0.0.1/tests/__init__.py +0 -0
- scdesigner-0.0.1/tests/test_negative_binomial.py +80 -0
- scdesigner-0.0.1/tests/test_simulator.py +38 -0
|
@@ -0,0 +1,23 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: scdesigner
|
|
3
|
+
Version: 0.0.1
|
|
4
|
+
Summary: Interactive simulation for rigorous and transparent multi-omics analysis.
|
|
5
|
+
Project-URL: Homepage, https://github.com/krisrs1128/scDesigner/
|
|
6
|
+
Project-URL: Issues, https://github.com/krisrs1128/scDesigner/Issues/
|
|
7
|
+
Author-email: Kris Sankaran <ksankaran@wisc.edu>
|
|
8
|
+
Classifier: License :: OSI Approved :: MIT License
|
|
9
|
+
Classifier: Operating System :: OS Independent
|
|
10
|
+
Classifier: Programming Language :: Python :: 3
|
|
11
|
+
Requires-Python: >=3.8
|
|
12
|
+
Requires-Dist: formulaic
|
|
13
|
+
Requires-Dist: numpy
|
|
14
|
+
Requires-Dist: pandas
|
|
15
|
+
Requires-Dist: rich
|
|
16
|
+
Requires-Dist: scanpy
|
|
17
|
+
Requires-Dist: scipy
|
|
18
|
+
Requires-Dist: torch
|
|
19
|
+
Requires-Dist: tqdm
|
|
20
|
+
Description-Content-Type: text/markdown
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
### scDesigner
|
|
@@ -0,0 +1,35 @@
|
|
|
1
|
+
[project]
|
|
2
|
+
name = "scdesigner"
|
|
3
|
+
version = "0.0.1"
|
|
4
|
+
authors = [
|
|
5
|
+
{ name="Kris Sankaran", email="ksankaran@wisc.edu" },
|
|
6
|
+
]
|
|
7
|
+
description = "Interactive simulation for rigorous and transparent multi-omics analysis."
|
|
8
|
+
readme = "README.md"
|
|
9
|
+
requires-python = ">=3.8"
|
|
10
|
+
classifiers = [
|
|
11
|
+
"Programming Language :: Python :: 3",
|
|
12
|
+
"License :: OSI Approved :: MIT License",
|
|
13
|
+
"Operating System :: OS Independent",
|
|
14
|
+
]
|
|
15
|
+
dependencies = [
|
|
16
|
+
"formulaic",
|
|
17
|
+
"numpy",
|
|
18
|
+
"pandas",
|
|
19
|
+
"rich",
|
|
20
|
+
"scanpy",
|
|
21
|
+
"scipy",
|
|
22
|
+
"torch",
|
|
23
|
+
"tqdm"
|
|
24
|
+
]
|
|
25
|
+
|
|
26
|
+
[project.urls]
|
|
27
|
+
Homepage = "https://github.com/krisrs1128/scDesigner/"
|
|
28
|
+
Issues = "https://github.com/krisrs1128/scDesigner/Issues/"
|
|
29
|
+
|
|
30
|
+
[build-system]
|
|
31
|
+
requires = ["formulaic", "pandas", "numpy", "scipy", "hatchling", "torch", "rich"]
|
|
32
|
+
build-backend = "hatchling.build"
|
|
33
|
+
|
|
34
|
+
[tool.hatch.metadata]
|
|
35
|
+
allow-direct-references = true
|
|
File without changes
|
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
from .formula import FormulaViewDataset, formula_loader, multiple_formula_loader, standardize_formula
|
|
2
|
+
from .group import FormulaGroupViewDataset, formula_group_loader, stack_collate, multiple_formula_group_loader
|
|
3
|
+
from .sparse import SparseMatrixDataset, SparseMatrixLoader
|
|
4
|
+
|
|
5
|
+
__all__ = [
|
|
6
|
+
"FormulaViewDataset",
|
|
7
|
+
"SparseMatrixDataset",
|
|
8
|
+
"SparseMatrixLoader",
|
|
9
|
+
"FormulaGroupViewDataset",
|
|
10
|
+
"formula_loader",
|
|
11
|
+
"formula_group_loader",
|
|
12
|
+
"stack_collate",
|
|
13
|
+
"multiple_formula_loader",
|
|
14
|
+
"multiple_formula_group_loader",
|
|
15
|
+
"standardize_formula"
|
|
16
|
+
]
|
|
@@ -0,0 +1,137 @@
|
|
|
1
|
+
from anndata import AnnData
|
|
2
|
+
from formulaic import model_matrix
|
|
3
|
+
from torch.utils.data import DataLoader, TensorDataset
|
|
4
|
+
from typing import Union
|
|
5
|
+
import numpy as np
|
|
6
|
+
import pandas as pd
|
|
7
|
+
import scipy.sparse
|
|
8
|
+
import torch
|
|
9
|
+
import torch.utils.data as td
|
|
10
|
+
import warnings
|
|
11
|
+
|
|
12
|
+
def formula_loader(
    adata: AnnData, formula=None, chunk_size=int(1e4), batch_size: int = None
):
    """Build a torch ``DataLoader`` yielding ``(x, y)`` pairs for ``adata``.

    ``x`` rows come from the design matrix obtained by applying ``formula``
    to ``adata.obs``; ``y`` rows come from ``adata.X``.

    Args:
        adata: Annotated data. Backed objects are wrapped in a
            ``FormulaViewDataset`` and read chunk by chunk; in-memory objects
            are converted to tensors up front.
        formula: Formula handed to ``formulaic.model_matrix``.
        chunk_size: Rows read per chunk when ``adata`` is backed.
        batch_size: Forwarded to the ``DataLoader``.

    Returns:
        The ``DataLoader``. Its dataset carries the design-matrix column
        names in an ``x_names`` attribute.
    """
    device = check_device()
    if adata.isbacked:
        ds = FormulaViewDataset(adata, formula, chunk_size, device)
        dataloader = td.DataLoader(ds, batch_size=batch_size)
        ds.x_names = model_matrix_names(adata, formula, ds.categories)
    else:
        # Densify every scipy sparse format (CSR, CSC, COO, ...) through the
        # public API; the previous check only recognised CSC matrices.
        y = adata.X
        if scipy.sparse.issparse(y):
            y = y.toarray()

        # create tensor-based loader
        x = model_matrix(formula, pd.DataFrame(adata.obs))
        ds = TensorDataset(
            torch.tensor(np.array(x), dtype=torch.float32).to(device),
            torch.tensor(y, dtype=torch.float32).to(device),
        )
        ds.x_names = list(x.columns)
        dataloader = DataLoader(ds, batch_size=batch_size, shuffle=False)

    return dataloader
|
|
36
|
+
|
|
37
|
+
def multiple_formula_loader(
    adata: AnnData, formulas: dict, chunk_size=int(1e4), batch_size: int = None
):
    """Create one ``formula_loader`` per entry of ``formulas``.

    Args:
        adata: Annotated data shared by every loader.
        formulas: Mapping from a key to the formula used for that key.
        chunk_size: Forwarded to ``formula_loader``.
        batch_size: Forwarded to ``formula_loader``.

    Returns:
        A dict with the same keys as ``formulas`` whose values are the
        corresponding data loaders.
    """
    return {
        name: formula_loader(adata, spec, chunk_size, batch_size)
        for name, spec in formulas.items()
    }
|
|
44
|
+
|
|
45
|
+
class FormulaViewDataset(td.Dataset):
    """Map-style dataset that reads a (backed) view chunk by chunk.

    A window of rows is loaded into memory, converted to a design matrix
    ``x`` and a response tensor ``y``, and reused until an index outside the
    window is requested.
    """

    def __init__(self, view, formula=None, chunk_size=int(1e4), device=None):
        """Store the view and prepare an empty cache.

        Args:
            view: Object supporting ``len``, row slicing followed by
                ``to_memory()``, and an ``obs`` table.
            formula: Formula used to build the design matrix.
            chunk_size: Number of rows held in memory at once.
            device: Target torch device; defaults to ``check_device()``.
        """
        super().__init__()
        self.view = view
        self.formula = formula
        self.len = len(view)
        # Remember the requested window length. Deriving it from the current
        # window would shrink it for good once a short tail chunk is loaded.
        self.chunk_size = chunk_size
        self.cur_range = range(0, min(self.len, chunk_size))
        self.categories = column_levels(view.obs)
        self.x = None
        self.y = None
        self.device = device or check_device()

    def __len__(self):
        return self.len

    def __getitem__(self, ix):
        """Return the ``(x, y)`` row pair at position ``ix``."""
        if self.x is None or ix not in self.cur_range:
            # Cache miss: load a fresh window starting at the requested row.
            self.cur_range = range(ix, min(ix + self.chunk_size, self.len))
            view_inmem = self.view[self.cur_range].to_memory()
            self.x = safe_model_matrix(
                view_inmem.obs, self.formula, self.categories
            ).to(self.device)
            counts = view_inmem.X
            if scipy.sparse.issparse(counts):
                counts = counts.toarray()
            self.y = torch.from_numpy(np.asarray(counts).astype(np.float32)).to(
                self.device
            )
        offset = ix - self.cur_range[0]
        return self.x[offset], self.y[offset]
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
def replace_cols(obs, categories):
    """Re-cast every categorical column of ``obs`` to the given levels.

    ``obs`` is modified in place and also returned. Columns whose dtype is
    not ``category`` are left untouched.

    Args:
        obs: Table whose categorical columns should be re-levelled.
        categories: Mapping from column name to the levels to use.

    Returns:
        The same ``obs`` object.
    """
    categorical_columns = [
        name for name in obs.columns if str(obs[name].dtype) == "category"
    ]
    for name in categorical_columns:
        levels = pd.CategoricalDtype(categories[name])
        obs[name] = obs[name].astype(levels)
    return obs
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
def model_matrix_names(adata, formula, categories):
    """Return the column names of the design matrix for ``formula``.

    For backed data only the first row is brought into memory. When
    ``formula`` is ``None`` the names of the ``obs`` columns are returned.

    Args:
        adata: Object exposing ``isbacked`` and ``obs``.
        formula: Formula for ``formulaic.model_matrix`` or ``None``.
        categories: Mapping from column name to categorical levels, applied
            through ``replace_cols`` before the matrix is built.

    Returns:
        A list of column names.
    """
    obs = adata[:1].to_memory().obs if adata.isbacked else adata.obs
    if formula is None:
        return list(obs.columns)

    releveled = replace_cols(obs, categories)
    design = model_matrix(formula, pd.DataFrame(releveled))
    return list(design.columns)
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
def safe_model_matrix(obs, formula, categories):
    """Build the design matrix for ``obs`` as a float32 tensor.

    Categorical columns are first re-cast to the levels in ``categories``
    via ``replace_cols``. When ``formula`` is ``None`` the input ``obs`` is
    returned unchanged (not converted to a tensor).

    Args:
        obs: Observation table.
        formula: Formula for ``formulaic.model_matrix`` or ``None``.
        categories: Mapping from column name to categorical levels.

    Returns:
        A ``torch.Tensor`` of dtype float32, or ``obs`` itself when no
        formula is given.
    """
    if formula is not None:
        releveled = replace_cols(obs, categories)
        design = model_matrix(formula, pd.DataFrame(releveled))
        values = np.array(design).astype(np.float32)
        return torch.from_numpy(values)
    return obs
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
def column_levels(obs):
    """Collect the distinct values of every categorical or object column.

    Args:
        obs: Table to inspect.

    Returns:
        A dict mapping each column whose dtype prints as ``category`` or
        ``object`` to the result of ``unique()`` on that column.
    """
    level_dtypes = ("category", "object")
    return {
        name: obs[name].unique()
        for name in obs.columns
        if str(obs[name].dtype) in level_dtypes
    }
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
def check_device():
    """Pick the torch device to compute on.

    Preference order: CUDA if available, then MPS if available, else CPU.

    Returns:
        The selected ``torch.device``.
    """
    if torch.cuda.is_available():
        name = "cuda"
    elif torch.backends.mps.is_available():
        name = "mps"
    else:
        name = "cpu"
    return torch.device(name)
|
|
117
|
+
|
|
118
|
+
def standardize_formula(formula: Union[str, dict], allowed_keys=None):
    """Normalise ``formula`` into a dict with one entry per allowed key.

    A plain string is assigned to the first entry of ``allowed_keys`` (the
    default parameter). Allowed keys missing from the result are filled with
    the intercept-only formula ``'~ 1'``, in the order of ``allowed_keys``.
    Keys outside ``allowed_keys`` trigger a ``UserWarning`` but are kept in
    the returned dict.

    The caller's dict is never modified; a new dict is returned.

    Args:
        formula: A formula string or a mapping from parameter name to formula.
        allowed_keys: Ordered collection of valid parameter names; the first
            one is the default parameter.

    Returns:
        A new dict containing every allowed key.

    Raises:
        ValueError: If ``allowed_keys`` is ``None`` or if ``formula`` shares
            no key with ``allowed_keys``.
    """
    # The first element of allowed_keys should be the name of default parameter
    if allowed_keys is None:
        raise ValueError("Internal error: allowed_keys must be specified")

    ordered_keys = list(allowed_keys)
    if isinstance(formula, str):
        formula = {ordered_keys[0]: formula}
    else:
        # Work on a copy so the defaults added below do not leak into the
        # dict owned by the caller.
        formula = dict(formula)

    formula_keys = set(formula)
    allowed_keys = set(ordered_keys)

    if not formula_keys & allowed_keys:
        raise ValueError(f"formula must have at least one of the following keys: {allowed_keys}")

    if extra_keys := formula_keys - allowed_keys:
        warnings.warn(
            f"Invalid formulas in dictionary will not be used: {extra_keys}",
            UserWarning,
        )

    # Fill defaults following the caller's ordering rather than set order,
    # so the key order of the result is reproducible between runs.
    for key in ordered_keys:
        formula.setdefault(key, '~ 1')
    return formula
|
|
@@ -0,0 +1,123 @@
|
|
|
1
|
+
from . import formula as fl
|
|
2
|
+
from anndata import AnnData
|
|
3
|
+
from formulaic import model_matrix
|
|
4
|
+
import numpy as np
|
|
5
|
+
import pandas as pd
|
|
6
|
+
import scipy
|
|
7
|
+
import torch
|
|
8
|
+
import torch.utils.data as td
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def formula_group_loader(
    adata: AnnData,
    formula=None,
    grouping_variable=None,
    chunk_size=int(1e4),
    batch_size: int = None,
):
    """Build a ``DataLoader`` yielding ``[x, y, groups]`` batches.

    Args:
        adata: Annotated data. When ``grouping_variable`` is ``None`` a
            constant categorical column ``_copula_group`` is added to
            ``adata.obs`` (the object is modified in place).
        formula: Formula handed to ``formulaic.model_matrix``.
        grouping_variable: Name of the ``obs`` column holding group labels.
        chunk_size: Rows read per chunk when ``adata`` is backed.
        batch_size: Forwarded to the ``DataLoader``.

    Returns:
        The ``DataLoader``. Its dataset carries ``x_names`` (design-matrix
        columns) and ``groups`` (the group levels).
    """
    device = fl.check_device()
    if grouping_variable is None:
        # The synthetic column only exists on this path, so the categorical
        # cast is confined to it.
        adata.obs["_copula_group"] = "shared_group"
        grouping_variable = "_copula_group"
        adata.obs["_copula_group"] = adata.obs["_copula_group"].astype("category")

    if adata.isbacked:
        ds = FormulaGroupViewDataset(
            adata, formula, grouping_variable, chunk_size, device
        )
        dataloader = td.DataLoader(ds, batch_size=batch_size, collate_fn=stack_collate())
        ds.x_names = fl.model_matrix_names(adata, formula, ds.categories)
    else:
        # Densify every scipy sparse format through the public API; the
        # previous check only recognised CSC matrices.
        y = adata.X
        if scipy.sparse.issparse(y):
            y = y.toarray()

        # A local categorical copy gives access to the levels even when the
        # user's column is stored as plain strings; obs itself is untouched.
        group_column = adata.obs[grouping_variable].astype("category")

        # wrap the entire data into a dataset
        x = model_matrix(formula, pd.DataFrame(adata.obs))
        ds = td.StackDataset(
            x=td.TensorDataset(
                torch.tensor(np.array(x), dtype=torch.float32).to(device)
            ),
            y=td.TensorDataset(torch.tensor(y, dtype=torch.float32).to(device)),
            # A plain list is indexed by position. Integer lookups on a
            # label-indexed Series rely on a deprecated positional fallback.
            groups=ListDataset(list(group_column)),
        )
        ds.groups = list(group_column.dtype.categories)
        ds.x_names = list(x.columns)
        dataloader = td.DataLoader(ds, batch_size=batch_size, collate_fn=stack_collate(pop=True))

    return dataloader
|
|
50
|
+
|
|
51
|
+
def multiple_formula_group_loader(adata: AnnData, formulas: dict, grouping_variable=None,
                                  chunk_size=int(1e4), batch_size: int = None):
    """Create one ``formula_group_loader`` per entry of ``formulas``.

    Args:
        adata: Annotated data shared by every loader.
        formulas: Mapping from a key to the formula used for that key.
        grouping_variable: Forwarded to ``formula_group_loader``.
        chunk_size: Forwarded to ``formula_group_loader``.
        batch_size: Forwarded to ``formula_group_loader``.

    Returns:
        A dict with the same keys as ``formulas`` whose values are the
        corresponding data loaders.
    """
    return {
        name: formula_group_loader(adata, spec, grouping_variable, chunk_size, batch_size)
        for name, spec in formulas.items()
    }
|
|
57
|
+
|
|
58
|
+
class FormulaGroupViewDataset(td.Dataset):
    """Chunked map-style dataset returning design row, response row and group.

    A window of rows is loaded into memory and reused until an index outside
    the window is requested.
    """

    def __init__(
        self,
        view,
        formula=None,
        grouping_variable=None,
        chunk_size=int(1e4),
        device=None,
    ):
        """Store the view and prepare an empty cache.

        Args:
            view: Object supporting ``len``, row slicing followed by
                ``to_memory()``, and an ``obs`` table.
            formula: Formula used to build the design matrix.
            grouping_variable: Name of the ``obs`` column with group labels;
                its levels are read through ``.dtype.categories``.
            chunk_size: Number of rows held in memory at once.
            device: Target torch device; defaults to ``fl.check_device()``.
        """
        super().__init__()
        self.device = device or fl.check_device()
        self.formula = formula
        self.categories = fl.column_levels(view.obs)
        self.grouping_variable = grouping_variable
        self.groups = list(self.categories[grouping_variable].dtype.categories)
        self.len = len(view)
        self.memberships = None
        self.view = view
        self.x = None
        self.y = None
        # Remember the requested window length. Deriving it from the current
        # window would shrink it for good once a short tail chunk is loaded.
        self.chunk_size = chunk_size
        self.cur_range = range(0, min(self.len, chunk_size))

    def __len__(self):
        return self.len

    def __getitem__(self, ix):
        """Return a dict with keys ``"x"``, ``"y"`` and ``"groups"`` for row ``ix``."""
        if self.x is None or ix not in self.cur_range:
            # Cache miss: load a fresh window starting at the requested row.
            self.cur_range = range(ix, min(ix + self.chunk_size, self.len))
            view_inmem = self.view[self.cur_range].to_memory()
            self.memberships = view_inmem.obs[self.grouping_variable]
            self.x = fl.safe_model_matrix(
                view_inmem.obs, self.formula, self.categories
            ).to(self.device)
            counts = view_inmem.X
            if scipy.sparse.issparse(counts):
                counts = counts.toarray()
            self.y = torch.from_numpy(np.asarray(counts).astype(np.float32)).to(
                self.device
            )
        offset = ix - self.cur_range[0]
        return {
            "x": self.x[offset],
            "y": self.y[offset],
            # Positional lookup: the Series is indexed by observation labels,
            # so a bare integer key is not a reliable row position.
            "groups": self.memberships.iloc[offset],
        }
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
class ListDataset(td.Dataset):
    """
    Minimal dataset wrapping an indexable sequence (used for group labels).
    """

    def __init__(self, list):
        # The parameter name shadows the builtin but is kept for callers
        # that pass it by keyword.
        self.list = list

    def __len__(self):
        """Number of stored items."""
        items = self.list
        return len(items)

    def __getitem__(self, idx):
        """Item stored at ``idx``."""
        items = self.list
        return items[idx]
|
|
114
|
+
|
|
115
|
+
def stack_collate(pop=False, groups=True):
    """Create a collate function that stacks ``x`` and ``y`` across a batch.

    Args:
        pop: If true, each sample's ``"x"`` and ``"y"`` entries are
            containers and their first element is used.
        groups: If true, the batch's ``"groups"`` entries are returned as a
            tuple in third position.

    Returns:
        A function mapping a list of sample dicts to ``[x, y, G]`` or
        ``[x, y]``.
    """

    def pick(entry):
        # Unwrap one level when samples store their tensors in a container.
        return entry[0] if pop else entry

    def f(batch):
        stacked_x = torch.stack([pick(sample["x"]) for sample in batch])
        stacked_y = torch.stack([pick(sample["y"]) for sample in batch])
        if not groups:
            return [stacked_x, stacked_y]
        labels = tuple(sample["groups"] for sample in batch)
        return [stacked_x, stacked_y, labels]

    return f
|
|
@@ -0,0 +1,39 @@
|
|
|
1
|
+
import anndata
|
|
2
|
+
import torch
|
|
3
|
+
import torch.utils.data as td
|
|
4
|
+
|
|
5
|
+
class SparseMatrixLoader:
    """Holds a ``DataLoader`` over a ``SparseMatrixDataset`` in ``self.loader``."""

    def __init__(self, adata: anndata.AnnData, batch_size: int = None):
        """Wrap ``adata`` in a dataset that batches rows itself.

        Args:
            adata: Annotated data passed to ``SparseMatrixDataset``.
            batch_size: Rows per batch, handled by the dataset.
        """
        dataset = SparseMatrixDataset(adata, batch_size)
        # The dataset already yields whole batches, so the DataLoader gets
        # batch_size=None and passes them through as they are.
        self.loader = td.DataLoader(dataset, batch_size=None)
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class SparseMatrixDataset(td.IterableDataset):
    """Iterable dataset yielding consecutive row batches as sparse CSR tensors."""

    def __init__(self, anndata: "anndata.AnnData", batch_size: int = None):
        """Keep a reference to ``anndata.X`` and fix the batch size.

        Args:
            anndata: Object whose ``X`` attribute is a scipy sparse matrix.
            batch_size: Rows per batch; defaults to all rows (one batch).
        """
        self.n_rows = anndata.X.shape[0]
        if batch_size is None:
            batch_size = self.n_rows

        self.sparse_matrix = anndata.X
        self.batch_size = batch_size

    def __iter__(self):
        """Yield one float32 sparse CSR tensor per batch of rows."""
        n_cols = self.sparse_matrix.shape[1]
        for start in range(0, self.n_rows, self.batch_size):
            stop = min(start + self.batch_size, self.n_rows)

            # Take rows, columns and values from one COO view so the three
            # arrays always have equal length. ``nonzero()`` drops explicitly
            # stored zeros while ``.data`` keeps them, which could make the
            # index and value arrays disagree.
            coo = self.sparse_matrix[start:stop, :].tocoo()
            indices = torch.stack(
                [torch.from_numpy(coo.row).long(), torch.from_numpy(coo.col).long()]
            )
            values = torch.as_tensor(coo.data, dtype=torch.float32)

            yield torch.sparse_coo_tensor(
                indices, values, (stop - start, n_cols)
            ).to_sparse_csr()

    def __len__(self):
        # Number of batches, rounding up for a final partial batch.
        return (self.n_rows + self.batch_size - 1) // self.batch_size
|
|
39
|
+
|
|
@@ -0,0 +1,65 @@
|
|
|
1
|
+
from .plot import (
|
|
2
|
+
plot_umap,
|
|
3
|
+
plot_hist,
|
|
4
|
+
compare_means,
|
|
5
|
+
compare_variances,
|
|
6
|
+
compare_standard_deviation,
|
|
7
|
+
compare_umap,
|
|
8
|
+
compare_pca,
|
|
9
|
+
)
|
|
10
|
+
from .aic_bic import compose_marginal_diagnose, compose_gcopula_diagnose
|
|
11
|
+
from .. import estimators as est
|
|
12
|
+
|
|
13
|
+
# Public names of the diagnose package. Every entry must be bound in this
# module, otherwise ``from ... import *`` raises AttributeError. "plot_pca"
# was listed without being imported above, and "gaussian_gcopula_diagnose"
# is defined below but was missing from the list.
__all__ = [
    "bernoulli_gcopula_diagnose",
    "bernoulli_regression_diagnose",
    "compare_means",
    "compare_pca",
    "compare_standard_deviation",
    "compare_umap",
    "compare_variances",
    "gaussian_gcopula_diagnose",
    "gaussian_regression_diagnose",
    "negbin_gcopula_diagnose",
    "negbin_regression_diagnose",
    "plot_hist",
    "plot_umap",
    "poisson_gcopula_diagnose",
    "poisson_regression_diagnose",
    "zinb_gcopula_diagnose",
    "zinb_regression_diagnose",
    "zip_regression_diagnose",
]
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
###############################################################################
|
|
36
|
+
## Methods for calculating marginal/gaussian copula AIC/BIC
|
|
37
|
+
###############################################################################
|
|
38
|
+
|
|
39
|
+
# Each "*_regression_diagnose" is built from a likelihood only; each
# "*_gcopula_diagnose" also takes the matching uniformizer. The first entry
# of allowed_keys is the parameter that a bare formula string refers to.

# Negative binomial
negbin_regression_diagnose = compose_marginal_diagnose(
    likelihood=est.negbin.negbin_regression_likelihood,
    allowed_keys=["mean", "dispersion"],
)
negbin_gcopula_diagnose = compose_gcopula_diagnose(
    likelihood=est.negbin.negbin_regression_likelihood,
    uniformizer=est.negbin.negbin_uniformizer,
    allowed_keys=["mean", "dispersion"],
)

# Poisson
poisson_regression_diagnose = compose_marginal_diagnose(
    likelihood=est.poisson.poisson_regression_likelihood,
    allowed_keys=["mean"],
)
poisson_gcopula_diagnose = compose_gcopula_diagnose(
    likelihood=est.poisson.poisson_regression_likelihood,
    uniformizer=est.poisson.poisson_uniformizer,
    allowed_keys=["mean"],
)

# Bernoulli
bernoulli_regression_diagnose = compose_marginal_diagnose(
    likelihood=est.bernoulli.bernoulli_regression_likelihood,
    allowed_keys=["mean"],
)
bernoulli_gcopula_diagnose = compose_gcopula_diagnose(
    likelihood=est.bernoulli.bernoulli_regression_likelihood,
    uniformizer=est.bernoulli.bernoulli_uniformizer,
    allowed_keys=["mean"],
)

# Zero-inflated negative binomial
zinb_regression_diagnose = compose_marginal_diagnose(
    likelihood=est.zero_inflated_negbin.zero_inflated_negbin_regression_likelihood,
    allowed_keys=["mean", "dispersion", "zero_inflation"],
)
zinb_gcopula_diagnose = compose_gcopula_diagnose(
    likelihood=est.zero_inflated_negbin.zero_inflated_negbin_regression_likelihood,
    uniformizer=est.zero_inflated_negbin.zero_inflated_negbin_uniformizer,
    allowed_keys=["mean", "dispersion", "zero_inflation"],
)

# Zero-inflated Poisson (marginal only)
zip_regression_diagnose = compose_marginal_diagnose(
    likelihood=est.zero_inflated_poisson.zero_inflated_poisson_regression_likelihood,
    allowed_keys=["mean", "zero_inflation"],
)

# Gaussian
gaussian_regression_diagnose = compose_marginal_diagnose(
    likelihood=est.gaussian.gaussian_regression_likelihood,
    allowed_keys=["mean", "sdev"],
)
gaussian_gcopula_diagnose = compose_gcopula_diagnose(
    likelihood=est.gaussian.gaussian_regression_likelihood,
    uniformizer=est.gaussian.gaussian_uniformizer,
    allowed_keys=["mean", "sdev"],
)
|
|
@@ -0,0 +1,119 @@
|
|
|
1
|
+
from .. import data
|
|
2
|
+
from anndata import AnnData
|
|
3
|
+
from formulaic import model_matrix
|
|
4
|
+
from scipy.stats import norm, multivariate_normal
|
|
5
|
+
from typing import Union
|
|
6
|
+
import numpy as np
|
|
7
|
+
import pandas as pd
|
|
8
|
+
import torch, scipy
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def marginal_aic_bic(likelihood, params: dict, adata: AnnData,
                     formula: Union[str, dict], allowed_keys: set,
                     param_order: list = None, transform: list = None,
                     chunk_size: int = int(1e4), batch_size=512):
    """Compute AIC and BIC of a fitted marginal model.

    The value of ``likelihood`` is negated and summed over batches to form
    the log-likelihood, so the callable presumably returns a negative
    log-likelihood (confirm against the estimator modules).

    Args:
        likelihood: Callable ``likelihood(params, x, y)``.
        params: Fitted parameters, flattened by ``likelihood_unwrapper``.
        adata: Annotated data to evaluate on.
        formula: Formula string or per-parameter formula dict.
        allowed_keys: Valid parameter names for ``formula``.
        param_order: Optional order in which parameters are flattened.
        transform: Optional per-parameter transforms applied when flattening.
        chunk_size: Forwarded to the data loaders.
        batch_size: Forwarded to the data loaders.

    Returns:
        Tuple ``(aic, bic)`` of Python floats.
    """
    device = data.formula.check_device()
    nsample = len(adata)
    params = likelihood_unwrapper(params, param_order, transform).to(device)
    nparam = len(params)

    # One loader per formula key; they are walked in lockstep below.
    formula = data.standardize_formula(formula, allowed_keys)
    loader_by_key = data.multiple_formula_loader(
        adata, formula, chunk_size=chunk_size, batch_size=batch_size
    )
    keys = list(loader_by_key.keys())

    ll = 0
    with torch.no_grad():
        for batches in zip(*loader_by_key.values()):
            x = {key: batch[0].to(device) for key, batch in zip(keys, batches)}
            # Every loader carries the same response; take it from the first.
            y = batches[0][1].to(device)
            ll += -likelihood(params, x, y)

    aic = 2 * nparam - 2 * ll
    bic = np.log(nsample) * nparam - 2 * ll
    return aic.cpu().item(), bic.cpu().item()
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def gaussian_copula_aic_bic(uniformizer, params: dict, adata: AnnData,
                            formula: Union[str, dict], allowed_keys: set, copula_groups=None):
    """Compute AIC and BIC of the Gaussian copula part of a fitted model.

    For each group in the covariance dict, the uniformized data are mapped to
    normal scores. The copula log-likelihood is the multivariate normal
    log-density minus the sum of standard normal log-densities. The parameter
    count per group is half the number of non-zero off-diagonal covariance
    entries. Contributions are summed over groups.

    Args:
        uniformizer: Callable ``uniformizer(params, X, y)`` whose result is
            indexed by row position.
        params: Fitted parameters including a ``"covariance"`` entry.
        adata: Annotated data to evaluate on.
        formula: Formula string or per-parameter formula dict.
        allowed_keys: Valid parameter names for ``formula``.
        copula_groups: Name of the ``obs`` column holding group labels; when
            ``None`` every row belongs to ``"shared_group"``.

    Returns:
        Tuple ``(aic, bic)``.
    """
    params = uniformizer_unwrapper(params)
    covariance = params['covariance']

    y = adata.X
    if isinstance(y, scipy.sparse._csc.csc_matrix):
        y = y.todense()

    formula = data.standardize_formula(formula, allowed_keys)
    X = {}
    for key in formula:
        X[key] = model_matrix(formula[key], pd.DataFrame(adata.obs))

    if copula_groups is None:
        copula_groups = "shared_group"
        memberships = np.array(["shared_group"] * y.shape[0])
    else:
        memberships = adata.obs[copula_groups]

    u = uniformizer(params, X, y)

    aic = 0  # in the future may add group-wise AIC/BIC
    bic = 0
    for g in covariance.keys():
        sigma = covariance[g]
        n_free = (np.sum(sigma != 0) - sigma.shape[0]) / 2

        rows = np.where(memberships == g)[0]
        z = norm().ppf(u[rows])
        copula_ll = multivariate_normal.logpdf(z, np.zeros(sigma.shape[0]), sigma)
        marginal_ll = norm.logpdf(z)

        # -2 * log-likelihood of the copula density for this group
        deviance = -2 * (np.sum(copula_ll) - np.sum(marginal_ll))
        aic += deviance + 2 * n_free
        bic += deviance + np.log(z.shape[0]) * n_free
    return aic, bic
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
def compose_marginal_diagnose(likelihood, allowed_keys: set, param_order: list = None, transform: list = None):
    """Bind a likelihood to ``marginal_aic_bic`` and return the closure.

    Args:
        likelihood: Likelihood callable forwarded to ``marginal_aic_bic``.
        allowed_keys: Valid parameter names for formulas.
        param_order: Optional flattening order of the parameters.
        transform: Optional per-parameter transforms.

    Returns:
        ``diagnose(params, adata, formula, chunk_size, batch_size)``, which
        returns what ``marginal_aic_bic`` returns.
    """
    def diagnose(params: dict, adata: AnnData, formula: Union[str, dict],
                 chunk_size: int = int(1e4), batch_size=512):
        return marginal_aic_bic(
            likelihood,
            params,
            adata,
            formula,
            allowed_keys,
            param_order,
            transform,
            chunk_size,
            batch_size,
        )

    return diagnose
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
def compose_gcopula_diagnose(likelihood, uniformizer, allowed_keys: set,
                             param_order: list = None, transform: list = None):
    """Bind a likelihood and uniformizer into a marginal-plus-copula diagnostic.

    Args:
        likelihood: Likelihood callable forwarded to ``marginal_aic_bic``.
        uniformizer: Uniformizer forwarded to ``gaussian_copula_aic_bic``.
        allowed_keys: Valid parameter names for formulas.
        param_order: Optional flattening order of the parameters.
        transform: Optional per-parameter transforms.

    Returns:
        ``diagnose(params, adata, formula, copula_groups, chunk_size,
        batch_size)``, which returns the tuple
        ``(marginal_aic, marginal_bic, copula_aic, copula_bic)``.
    """
    def diagnose(params: dict, adata: AnnData, formula: str, copula_groups=None,
                 chunk_size: int = int(1e4), batch_size=512):
        marginal_scores = marginal_aic_bic(
            likelihood, params, adata, formula, allowed_keys,
            param_order, transform, chunk_size, batch_size,
        )
        marginal_aic, marginal_bic = marginal_scores
        copula_scores = gaussian_copula_aic_bic(
            uniformizer, params, adata, formula, allowed_keys, copula_groups
        )
        copula_aic, copula_bic = copula_scores
        return marginal_aic, marginal_bic, copula_aic, copula_bic

    return diagnose
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
###############################################################################
|
|
97
|
+
## Helper for converting params to match likelihood/uniformizer input format
|
|
98
|
+
###############################################################################
|
|
99
|
+
|
|
100
|
+
def likelihood_unwrapper(params: dict, param_order: list = None, transform: list = None):
    """Flatten parameter tables into one 1-D tensor.

    Each selected entry of ``params`` is read through ``.values``, flattened
    row by row and optionally transformed; the pieces are concatenated in
    processing order.

    Args:
        params: Mapping from parameter name to a two-dimensional table
            exposing ``.values`` and ``.shape``.
        param_order: Names to process, in order. By default every key except
            ``"covariance"`` in the dict's own order.
        transform: Optional list of callables; the i-th one is applied to
            the i-th flattened parameter.

    Returns:
        A 1-D ``torch.Tensor``.
    """
    if param_order is None:
        keys_to_process = [name for name in params if name != "covariance"]
    else:
        keys_to_process = param_order

    pieces = []
    for position, name in enumerate(keys_to_process):
        table = params[name]
        n_entries = table.shape[0] * table.shape[1]
        flat = torch.Tensor(table.values).reshape(1, n_entries)[0]
        pieces.append(flat if transform is None else transform[position](flat))

    return torch.cat(pieces, dim=0)
|
|
112
|
+
|
|
113
|
+
def uniformizer_unwrapper(params):
    """Convert parameter tables to raw arrays for a uniformizer.

    Every entry except ``"covariance"`` is replaced by its ``.values``. The
    covariance entry always becomes a dict of arrays: a single table is
    stored under the key ``"shared_group"``, and a dict of tables is
    converted entry by entry.

    Args:
        params: Mapping from parameter name to a table exposing ``.values``;
            must contain ``"covariance"``.

    Returns:
        A new dict; the input mapping is not modified.
    """
    unwrapped = {}
    for name, value in params.items():
        unwrapped[name] = value if name == 'covariance' else value.values

    covariance = unwrapped['covariance']
    if isinstance(covariance, dict):
        unwrapped['covariance'] = {group: covariance[group].values for group in covariance}
    else:
        unwrapped['covariance'] = {'shared_group': covariance.values}
    return unwrapped
|