scdesigner 0.0.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of scdesigner might be problematic. Click here for more details.

Files changed (70)
  1. scdesigner-0.0.1/.gitignore +8 -0
  2. scdesigner-0.0.1/PKG-INFO +23 -0
  3. scdesigner-0.0.1/README.md +2 -0
  4. scdesigner-0.0.1/pyproject.toml +35 -0
  5. scdesigner-0.0.1/src/scdesigner/__init__.py +0 -0
  6. scdesigner-0.0.1/src/scdesigner/data/__init__.py +16 -0
  7. scdesigner-0.0.1/src/scdesigner/data/formula.py +137 -0
  8. scdesigner-0.0.1/src/scdesigner/data/group.py +123 -0
  9. scdesigner-0.0.1/src/scdesigner/data/sparse.py +39 -0
  10. scdesigner-0.0.1/src/scdesigner/diagnose/__init__.py +65 -0
  11. scdesigner-0.0.1/src/scdesigner/diagnose/aic_bic.py +119 -0
  12. scdesigner-0.0.1/src/scdesigner/diagnose/plot.py +242 -0
  13. scdesigner-0.0.1/src/scdesigner/estimators/__init__.py +27 -0
  14. scdesigner-0.0.1/src/scdesigner/estimators/bernoulli.py +85 -0
  15. scdesigner-0.0.1/src/scdesigner/estimators/gaussian.py +121 -0
  16. scdesigner-0.0.1/src/scdesigner/estimators/gaussian_copula_factory.py +152 -0
  17. scdesigner-0.0.1/src/scdesigner/estimators/glm_factory.py +75 -0
  18. scdesigner-0.0.1/src/scdesigner/estimators/negbin.py +129 -0
  19. scdesigner-0.0.1/src/scdesigner/estimators/pnmf.py +160 -0
  20. scdesigner-0.0.1/src/scdesigner/estimators/poisson.py +100 -0
  21. scdesigner-0.0.1/src/scdesigner/estimators/zero_inflated_negbin.py +195 -0
  22. scdesigner-0.0.1/src/scdesigner/estimators/zero_inflated_poisson.py +85 -0
  23. scdesigner-0.0.1/src/scdesigner/format/__init__.py +4 -0
  24. scdesigner-0.0.1/src/scdesigner/format/format.py +20 -0
  25. scdesigner-0.0.1/src/scdesigner/format/print.py +30 -0
  26. scdesigner-0.0.1/src/scdesigner/minimal/__init__.py +17 -0
  27. scdesigner-0.0.1/src/scdesigner/minimal/bernoulli.py +61 -0
  28. scdesigner-0.0.1/src/scdesigner/minimal/composite.py +119 -0
  29. scdesigner-0.0.1/src/scdesigner/minimal/copula.py +33 -0
  30. scdesigner-0.0.1/src/scdesigner/minimal/formula.py +23 -0
  31. scdesigner-0.0.1/src/scdesigner/minimal/gaussian.py +65 -0
  32. scdesigner-0.0.1/src/scdesigner/minimal/kwargs.py +24 -0
  33. scdesigner-0.0.1/src/scdesigner/minimal/loader.py +166 -0
  34. scdesigner-0.0.1/src/scdesigner/minimal/marginal.py +140 -0
  35. scdesigner-0.0.1/src/scdesigner/minimal/negbin.py +73 -0
  36. scdesigner-0.0.1/src/scdesigner/minimal/positive_nonnegative_matrix_factorization.py +231 -0
  37. scdesigner-0.0.1/src/scdesigner/minimal/scd3.py +95 -0
  38. scdesigner-0.0.1/src/scdesigner/minimal/scd3_instances.py +50 -0
  39. scdesigner-0.0.1/src/scdesigner/minimal/simulator.py +25 -0
  40. scdesigner-0.0.1/src/scdesigner/minimal/standard_covariance.py +124 -0
  41. scdesigner-0.0.1/src/scdesigner/minimal/transform.py +145 -0
  42. scdesigner-0.0.1/src/scdesigner/minimal/zero_inflated_negbin.py +86 -0
  43. scdesigner-0.0.1/src/scdesigner/predictors/__init__.py +15 -0
  44. scdesigner-0.0.1/src/scdesigner/predictors/bernoulli.py +9 -0
  45. scdesigner-0.0.1/src/scdesigner/predictors/gaussian.py +16 -0
  46. scdesigner-0.0.1/src/scdesigner/predictors/negbin.py +17 -0
  47. scdesigner-0.0.1/src/scdesigner/predictors/poisson.py +12 -0
  48. scdesigner-0.0.1/src/scdesigner/predictors/zero_inflated_negbin.py +18 -0
  49. scdesigner-0.0.1/src/scdesigner/predictors/zero_inflated_poisson.py +18 -0
  50. scdesigner-0.0.1/src/scdesigner/samplers/__init__.py +23 -0
  51. scdesigner-0.0.1/src/scdesigner/samplers/bernoulli.py +27 -0
  52. scdesigner-0.0.1/src/scdesigner/samplers/gaussian.py +25 -0
  53. scdesigner-0.0.1/src/scdesigner/samplers/glm_factory.py +41 -0
  54. scdesigner-0.0.1/src/scdesigner/samplers/negbin.py +25 -0
  55. scdesigner-0.0.1/src/scdesigner/samplers/poisson.py +25 -0
  56. scdesigner-0.0.1/src/scdesigner/samplers/zero_inflated_negbin.py +40 -0
  57. scdesigner-0.0.1/src/scdesigner/samplers/zero_inflated_poisson.py +16 -0
  58. scdesigner-0.0.1/src/scdesigner/simulators/__init__.py +31 -0
  59. scdesigner-0.0.1/src/scdesigner/simulators/composite_regressor.py +72 -0
  60. scdesigner-0.0.1/src/scdesigner/simulators/glm_simulator.py +167 -0
  61. scdesigner-0.0.1/src/scdesigner/simulators/pnmf_regression.py +61 -0
  62. scdesigner-0.0.1/src/scdesigner/transform/__init__.py +7 -0
  63. scdesigner-0.0.1/src/scdesigner/transform/amplify.py +14 -0
  64. scdesigner-0.0.1/src/scdesigner/transform/mask.py +33 -0
  65. scdesigner-0.0.1/src/scdesigner/transform/nullify.py +25 -0
  66. scdesigner-0.0.1/src/scdesigner/transform/split.py +23 -0
  67. scdesigner-0.0.1/src/scdesigner/transform/substitute.py +14 -0
  68. scdesigner-0.0.1/tests/__init__.py +0 -0
  69. scdesigner-0.0.1/tests/test_negative_binomial.py +80 -0
  70. scdesigner-0.0.1/tests/test_simulator.py +38 -0
@@ -0,0 +1,8 @@
1
+ build*
2
+ *egg*
3
+ *cache*
4
+ examples/data/*.h5ad
5
+ examples/data
6
+ *logs*
7
+ *DS_store
8
+ tests_langtian/
@@ -0,0 +1,23 @@
1
+ Metadata-Version: 2.4
2
+ Name: scdesigner
3
+ Version: 0.0.1
4
+ Summary: Interactive simulation for rigorous and transparent multi-omics analysis.
5
+ Project-URL: Homepage, https://github.com/krisrs1128/scDesigner/
6
+ Project-URL: Issues, https://github.com/krisrs1128/scDesigner/Issues/
7
+ Author-email: Kris Sankaran <ksankaran@wisc.edu>
8
+ Classifier: License :: OSI Approved :: MIT License
9
+ Classifier: Operating System :: OS Independent
10
+ Classifier: Programming Language :: Python :: 3
11
+ Requires-Python: >=3.8
12
+ Requires-Dist: formulaic
13
+ Requires-Dist: numpy
14
+ Requires-Dist: pandas
15
+ Requires-Dist: rich
16
+ Requires-Dist: scanpy
17
+ Requires-Dist: scipy
18
+ Requires-Dist: torch
19
+ Requires-Dist: tqdm
20
+ Description-Content-Type: text/markdown
21
+
22
+
23
+ ### scDesigner
@@ -0,0 +1,2 @@
1
+
2
+ ### scDesigner
@@ -0,0 +1,35 @@
1
+ [project]
2
+ name = "scdesigner"
3
+ version = "0.0.1"
4
+ authors = [
5
+ { name="Kris Sankaran", email="ksankaran@wisc.edu" },
6
+ ]
7
+ description = "Interactive simulation for rigorous and transparent multi-omics analysis."
8
+ readme = "README.md"
9
+ requires-python = ">=3.8"
10
+ classifiers = [
11
+ "Programming Language :: Python :: 3",
12
+ "License :: OSI Approved :: MIT License",
13
+ "Operating System :: OS Independent",
14
+ ]
15
+ dependencies = [
16
+ "formulaic",
17
+ "numpy",
18
+ "pandas",
19
+ "rich",
20
+ "scanpy",
21
+ "scipy",
22
+ "torch",
23
+ "tqdm"
24
+ ]
25
+
26
+ [project.urls]
27
+ Homepage = "https://github.com/krisrs1128/scDesigner/"
28
+ Issues = "https://github.com/krisrs1128/scDesigner/Issues/"
29
+
30
+ [build-system]
31
+ requires = ["formulaic", "pandas", "numpy", "scipy", "hatchling", "torch", "rich"]
32
+ build-backend = "hatchling.build"
33
+
34
+ [tool.hatch.metadata]
35
+ allow-direct-references = true
File without changes
@@ -0,0 +1,16 @@
1
+ from .formula import FormulaViewDataset, formula_loader, multiple_formula_loader, standardize_formula
2
+ from .group import FormulaGroupViewDataset, formula_group_loader, stack_collate, multiple_formula_group_loader
3
+ from .sparse import SparseMatrixDataset, SparseMatrixLoader
4
+
5
# Public API of the ``data`` subpackage (kept in alphabetical order).
__all__ = [
    "FormulaGroupViewDataset",
    "FormulaViewDataset",
    "SparseMatrixDataset",
    "SparseMatrixLoader",
    "formula_group_loader",
    "formula_loader",
    "multiple_formula_group_loader",
    "multiple_formula_loader",
    "stack_collate",
    "standardize_formula",
]
@@ -0,0 +1,137 @@
1
+ from anndata import AnnData
2
+ from formulaic import model_matrix
3
+ from torch.utils.data import DataLoader, TensorDataset
4
+ from typing import Union
5
+ import numpy as np
6
+ import pandas as pd
7
+ import scipy.sparse
8
+ import torch
9
+ import torch.utils.data as td
10
+ import warnings
11
+
12
def formula_loader(
    adata: AnnData, formula=None, chunk_size=int(1e4), batch_size: int = None
):
    """Build a DataLoader over ``(design_row, expression_row)`` pairs for ``adata``.

    Args:
        adata: AnnData object. Backed objects are streamed in chunks through
            :class:`FormulaViewDataset`; in-memory objects are converted to
            tensors up front.
        formula: formulaic formula string used to build the design matrix
            from ``adata.obs``.
        chunk_size: Number of cells materialised per chunk (backed mode only).
        batch_size: Passed straight to ``DataLoader``.

    Returns:
        A ``DataLoader`` whose dataset carries an ``x_names`` attribute listing
        the design-matrix column names.
    """
    device = check_device()
    if adata.isbacked:
        ds = FormulaViewDataset(adata, formula, chunk_size, device)
        dataloader = td.DataLoader(ds, batch_size=batch_size)
        ds.x_names = model_matrix_names(adata, formula, ds.categories)
    else:
        # Densify any scipy sparse format (CSR, CSC, COO, ...). Checking only
        # for CSC let CSR matrices (AnnData's usual layout) reach torch.tensor
        # still sparse, which fails.
        y = adata.X
        if scipy.sparse.issparse(y):
            y = y.toarray()

        # Create a tensor-based loader holding the whole dataset on `device`.
        x = model_matrix(formula, pd.DataFrame(adata.obs))
        ds = TensorDataset(
            torch.tensor(np.array(x), dtype=torch.float32).to(device),
            torch.tensor(np.asarray(y), dtype=torch.float32).to(device),
        )
        ds.x_names = list(x.columns)
        dataloader = DataLoader(ds, batch_size=batch_size, shuffle=False)

    return dataloader
36
+
37
def multiple_formula_loader(
    adata: AnnData, formulas: dict, chunk_size=int(1e4), batch_size: int = None
):
    """Create one :func:`formula_loader` per entry of ``formulas``.

    Args:
        adata: AnnData object shared by every loader.
        formulas: Mapping from parameter name to formula string.
        chunk_size: Forwarded to :func:`formula_loader`.
        batch_size: Forwarded to :func:`formula_loader`.

    Returns:
        Dict with the same keys as ``formulas``, each mapped to its DataLoader.
    """
    return {
        name: formula_loader(adata, spec, chunk_size, batch_size)
        for name, spec in formulas.items()
    }
44
+
45
class FormulaViewDataset(td.Dataset):
    """Map-style dataset that builds design-matrix rows lazily, chunk by chunk.

    When an index outside the cached chunk is requested, the rows starting at
    that index (up to ``chunk_size`` of them) are loaded into memory with
    ``to_memory()``. Their design matrix is built with the category levels
    captured at construction time, and both ``x`` and ``y`` are moved to
    ``device``. Items are ``(x_row, y_row)`` tuples.
    """

    def __init__(self, view, formula=None, chunk_size=int(1e4), device=None):
        super().__init__()
        self.view = view
        self.formula = formula
        self.len = len(view)
        # Remember the requested chunk length. Deriving it from the current
        # range (as before) shrank it permanently after the short final chunk.
        self.chunk_size = chunk_size
        self.cur_range = range(0, min(self.len, chunk_size))
        self.categories = column_levels(view.obs)
        self.x = None
        self.y = None
        self.device = device or check_device()

    def __len__(self):
        return self.len

    def __getitem__(self, ix):
        if self.x is None or ix not in self.cur_range:
            self.cur_range = range(ix, min(ix + self.chunk_size, self.len))
            view_inmem = self.view[self.cur_range].to_memory()
            self.x = safe_model_matrix(
                view_inmem.obs, self.formula, self.categories
            ).to(self.device)
            # X may be sparse or already dense once in memory.
            counts = view_inmem.X
            if scipy.sparse.issparse(counts):
                counts = counts.toarray()
            self.y = torch.from_numpy(np.asarray(counts).astype(np.float32)).to(
                self.device
            )
        offset = ix - self.cur_range[0]
        return self.x[offset], self.y[offset]
71
+
72
+
73
def replace_cols(obs, categories):
    """Return a copy of ``obs`` whose categorical columns use fixed levels.

    Each column with ``category`` dtype is re-cast to
    ``pd.CategoricalDtype(categories[column])``. This keeps chunks of data
    encoded with the same level set, and therefore the same dummy columns.
    The input frame is not modified. Previously it was mutated in place,
    which silently changed ``adata.obs`` when called on in-memory data.

    NOTE(review): object-dtype columns are left untouched even though
    ``column_levels`` records levels for them — confirm whether they should
    be fixed too.
    """
    obs = obs.copy()
    for k in obs.columns:
        if str(obs[k].dtype) == "category":
            obs[k] = obs[k].astype(pd.CategoricalDtype(categories[k]))
    return obs
78
+
79
+
80
def model_matrix_names(adata, formula, categories):
    """Return the column names of the design matrix ``formula`` yields for ``adata``.

    For backed data only the first row is materialised, because only the
    column layout is needed. Categorical columns are given the levels in
    ``categories`` before the matrix is built. With ``formula=None`` the raw
    ``obs`` column names are returned.
    """
    obs = adata[:1].to_memory().obs if adata.isbacked else adata.obs

    if formula is not None:
        design = model_matrix(formula, pd.DataFrame(replace_cols(obs, categories)))
        return list(design.columns)
    return list(obs.columns)
91
+
92
+
93
def safe_model_matrix(obs, formula, categories):
    """Build a float32 design-matrix tensor from ``obs`` using fixed category levels.

    If ``formula`` is ``None`` the ``obs`` object itself is returned unchanged.
    """
    if formula is not None:
        fixed_obs = replace_cols(obs, categories)
        design = model_matrix(formula, pd.DataFrame(fixed_obs))
        return torch.from_numpy(np.array(design).astype(np.float32))
    return obs
100
+
101
+
102
def column_levels(obs):
    """Map every ``category``/``object`` column of ``obs`` to its unique values.

    Returns:
        Dict ``{column_name: obs[column_name].unique()}`` restricted to columns
        whose dtype string is ``"category"`` or ``"object"``.
    """
    return {
        name: obs[name].unique()
        for name in obs.columns
        if str(obs[name].dtype) in ["category", "object"]
    }
109
+
110
+
111
def check_device():
    """Return the preferred torch device: CUDA, then Apple MPS, then CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")
117
+
118
def standardize_formula(formula: Union[str, dict], allowed_keys = None):
    """Normalise a formula specification into a dict over ``allowed_keys``.

    Args:
        formula: A single formula string, used for the default parameter
            ``allowed_keys[0]``, or a dict mapping parameter names to formulas.
        allowed_keys: Ordered parameter names. The first entry is the default
            parameter for a bare string formula.

    Returns:
        A new dict holding only allowed keys. User-supplied formulas come first,
        in the user's order; every missing allowed key follows, in
        ``allowed_keys`` order, with the intercept-only formula ``'~ 1'``.
        The caller's dict is never modified.

    Raises:
        ValueError: If ``allowed_keys`` is missing, or ``formula`` shares no
            key with it.
    """
    if allowed_keys is None:
        raise ValueError("Internal error: allowed_keys must be specified")
    ordered_keys = list(allowed_keys)
    # Copy user dicts so that filling in defaults does not mutate caller state.
    formula = {ordered_keys[0]: formula} if isinstance(formula, str) else dict(formula)

    formula_keys = set(formula.keys())
    allowed = set(ordered_keys)

    if not formula_keys & allowed:
        raise ValueError(f"formula must have at least one of the following keys: {allowed}")

    if extra_keys := formula_keys - allowed:
        warnings.warn(
            f"Invalid formulas in dictionary will not be used: {extra_keys}",
            UserWarning,
        )

    # Actually drop the keys the warning says will not be used, then add
    # defaults in a deterministic (allowed_keys) order.
    result = {k: v for k, v in formula.items() if k in allowed}
    for key in ordered_keys:
        result.setdefault(key, '~ 1')
    return result
@@ -0,0 +1,123 @@
1
+ from . import formula as fl
2
+ from anndata import AnnData
3
+ from formulaic import model_matrix
4
+ import numpy as np
5
+ import pandas as pd
6
+ import scipy
7
+ import torch
8
+ import torch.utils.data as td
9
+
10
+
11
def formula_group_loader(
    adata: AnnData,
    formula=None,
    grouping_variable=None,
    chunk_size=int(1e4),
    batch_size: int = None,
):
    """Build a DataLoader yielding ``[x, y, groups]`` batches for ``adata``.

    Args:
        adata: AnnData object. Backed objects are streamed through
            :class:`FormulaGroupViewDataset`; in-memory ones are tensorised up front.
        formula: formulaic formula string for the design matrix.
        grouping_variable: Column of ``adata.obs`` holding group labels. When
            ``None``, a categorical ``_copula_group`` column equal to
            ``"shared_group"`` is added to ``adata.obs`` (this mutates the
            caller's object) and used instead.
        chunk_size: Cells materialised per chunk (backed mode only).
        batch_size: Passed to ``DataLoader``.

    Returns:
        A ``DataLoader`` collated by :func:`stack_collate`. Its dataset carries
        ``x_names`` and ``groups`` attributes.
    """
    device = fl.check_device()
    if grouping_variable is None:
        adata.obs["_copula_group"] = "shared_group"
        grouping_variable = "_copula_group"
        adata.obs["_copula_group"] = adata.obs["_copula_group"].astype("category")

    if adata.isbacked:
        ds = FormulaGroupViewDataset(
            adata, formula, grouping_variable, chunk_size, device
        )
        dataloader = td.DataLoader(ds, batch_size=batch_size, collate_fn=stack_collate())
        ds.x_names = fl.model_matrix_names(adata, formula, ds.categories)
    else:
        # Densify any scipy sparse format, not only CSC.
        y = adata.X
        if scipy.sparse.issparse(y):
            y = y.toarray()

        # Non-categorical grouping columns have no `.dtype.categories`; cast them.
        memberships = adata.obs[grouping_variable]
        if not isinstance(memberships.dtype, pd.CategoricalDtype):
            memberships = memberships.astype("category")

        # Wrap the entire data into a dataset. Group labels are passed as a
        # plain list so that integer indexing is positional, not label-based.
        x = model_matrix(formula, pd.DataFrame(adata.obs))
        ds = td.StackDataset(
            x=td.TensorDataset(
                torch.tensor(np.array(x), dtype=torch.float32).to(device)
            ),
            y=td.TensorDataset(torch.tensor(np.asarray(y), dtype=torch.float32).to(device)),
            groups=ListDataset(memberships.tolist()),
        )
        ds.groups = list(memberships.dtype.categories)
        ds.x_names = list(x.columns)
        dataloader = td.DataLoader(ds, batch_size=batch_size, collate_fn=stack_collate(pop=True))

    return dataloader
50
+
51
def multiple_formula_group_loader(adata: AnnData, formulas: dict, grouping_variable=None,
                                  chunk_size=int(1e4), batch_size: int = None):
    """Create one :func:`formula_group_loader` per entry of ``formulas``.

    Returns:
        Dict with the keys of ``formulas``, each mapped to its grouped DataLoader.
    """
    return {
        name: formula_group_loader(adata, spec, grouping_variable, chunk_size, batch_size)
        for name, spec in formulas.items()
    }
57
+
58
class FormulaGroupViewDataset(td.Dataset):
    """Map-style dataset returning ``{"x", "y", "groups"}`` rows, loaded chunk by chunk.

    When an index outside the cached chunk is requested, up to ``chunk_size``
    rows starting at that index are materialised with ``to_memory()``. Their
    design matrix, expression values and group labels are then cached.
    """

    def __init__(
        self,
        view,
        formula=None,
        grouping_variable=None,
        chunk_size=int(1e4),
        device=None,
    ):
        super().__init__()
        self.device = device or fl.check_device()
        self.formula = formula
        self.categories = fl.column_levels(view.obs)
        self.grouping_variable = grouping_variable
        # column_levels only records category/object columns; for an object
        # column it returns a plain ndarray without `.dtype.categories`.
        levels = self.categories.get(grouping_variable)
        if levels is None:
            levels = view.obs[grouping_variable].unique()
        if isinstance(levels.dtype, pd.CategoricalDtype):
            self.groups = list(levels.dtype.categories)
        else:
            self.groups = list(levels)
        self.len = len(view)
        # Keep the requested chunk length so it does not shrink after the
        # short final chunk has been loaded.
        self.chunk_size = chunk_size
        self.memberships = None
        self.view = view
        self.x = None
        self.y = None
        self.cur_range = range(0, min(self.len, chunk_size))

    def __len__(self):
        return self.len

    def __getitem__(self, ix):
        if self.x is None or ix not in self.cur_range:
            self.cur_range = range(ix, min(ix + self.chunk_size, self.len))
            view_inmem = self.view[self.cur_range].to_memory()
            self.memberships = view_inmem.obs[self.grouping_variable]
            self.x = fl.safe_model_matrix(
                view_inmem.obs, self.formula, self.categories
            ).to(self.device)
            counts = view_inmem.X
            if scipy.sparse.issparse(counts):
                counts = counts.toarray()
            self.y = torch.from_numpy(np.asarray(counts).astype(np.float32)).to(
                self.device
            )
        offset = ix - self.cur_range[0]
        return {
            "x": self.x[offset],
            "y": self.y[offset],
            # .iloc: obs is indexed by cell names, so plain [] is label-based.
            "groups": self.memberships.iloc[offset],
        }
99
+
100
+
101
class ListDataset(td.Dataset):
    """
    Simple map-style dataset that stores per-cell group labels.

    Indexing is always positional. Objects exposing ``.iloc`` (pandas Series)
    are read through it, because ``Series[int]`` is label-based for integer
    indexes and a deprecated positional fallback otherwise.
    """

    def __init__(self, list):
        # The parameter keeps its original name (it shadows the builtin here)
        # so keyword callers keep working.
        self.list = list

    def __len__(self):
        return len(self.list)

    def __getitem__(self, idx):
        positional = getattr(self.list, "iloc", None)
        if positional is not None:
            return positional[idx]
        return self.list[idx]
114
+
115
def stack_collate(pop=False, groups=True):
    """Return a ``collate_fn`` that stacks dict samples into ``[x, y(, groups)]``.

    Args:
        pop: When true, each sample's ``"x"``/``"y"`` entry is a 1-tuple (as
            produced by ``TensorDataset``) and its first element is used.
        groups: When true, a tuple of the samples' ``"groups"`` values is
            appended as a third element.
    """
    def collate(batch):
        if pop:
            pick = lambda value: value[0]
        else:
            pick = lambda value: value
        stacked_x = torch.stack([pick(sample["x"]) for sample in batch])
        stacked_y = torch.stack([pick(sample["y"]) for sample in batch])
        if not groups:
            return [stacked_x, stacked_y]
        labels = tuple(sample["groups"] for sample in batch)
        return [stacked_x, stacked_y, labels]
    return collate
@@ -0,0 +1,39 @@
1
+ import anndata
2
+ import torch
3
+ import torch.utils.data as td
4
+
5
class SparseMatrixLoader:
    """Expose a :class:`SparseMatrixDataset` through a ``DataLoader`` at ``.loader``."""

    def __init__(self, adata: anndata.AnnData, batch_size: int = None):
        dataset = SparseMatrixDataset(adata, batch_size)
        # The dataset already yields whole batches, so DataLoader's own
        # automatic batching is switched off with batch_size=None.
        self.loader = td.DataLoader(dataset, batch_size=None)
9
+
10
+
11
class SparseMatrixDataset(td.IterableDataset):
    """Iterate over row batches of ``anndata.X`` as sparse CSR torch tensors.

    ``anndata.X`` is assumed to be a scipy sparse matrix (or to return one when
    sliced). Each yielded tensor has shape ``(rows_in_batch, n_columns)`` and
    dtype float32.
    """

    # The annotation is a string so the parameter named `anndata` does not
    # collide with the `anndata` module when it is evaluated.
    def __init__(self, anndata: "anndata.AnnData", batch_size: int = None):
        self.n_rows = anndata.X.shape[0]
        if batch_size is None:
            # One batch holding every row; `or 1` keeps an empty matrix from
            # producing a zero step / zero division below.
            batch_size = self.n_rows or 1

        self.sparse_matrix = anndata.X
        self.batch_size = batch_size

    def __iter__(self):
        n_cols = self.sparse_matrix.shape[1]
        for start in range(0, self.n_rows, self.batch_size):
            stop = min(start + self.batch_size, self.n_rows)
            # Contiguous slice instead of fancy indexing with a range.
            batch_coo = self.sparse_matrix[start:stop].tocoo()

            # Take indices and values from the same COO view. `nonzero()` skips
            # explicitly stored zeros while `.data` keeps them, which made the
            # index and value arrays disagree in length.
            indices = torch.stack([
                torch.as_tensor(batch_coo.row, dtype=torch.long),
                torch.as_tensor(batch_coo.col, dtype=torch.long),
            ])
            values = torch.as_tensor(batch_coo.data, dtype=torch.float32)

            yield torch.sparse_coo_tensor(
                indices, values, (stop - start, n_cols)
            ).coalesce().to_sparse_csr()

    def __len__(self):
        return (self.n_rows + self.batch_size - 1) // self.batch_size
39
+
@@ -0,0 +1,65 @@
1
+ from .plot import (
2
+ plot_umap,
3
+ plot_hist,
4
+ compare_means,
5
+ compare_variances,
6
+ compare_standard_deviation,
7
+ compare_umap,
8
+ compare_pca,
9
+ )
10
+ from .aic_bic import compose_marginal_diagnose, compose_gcopula_diagnose
11
+ from .. import estimators as est
12
+
13
# Public API of ``diagnose``. Every name here must exist in the module, or
# ``from ... import *`` raises AttributeError. "plot_pca" was listed but is
# neither imported from .plot nor defined here, so it is removed.
# "gaussian_gcopula_diagnose" is defined below and was missing.
__all__ = [
    "bernoulli_gcopula_diagnose",
    "bernoulli_regression_diagnose",
    "compare_means",
    "compare_pca",
    "compare_standard_deviation",
    "compare_umap",
    "compare_variances",
    "gaussian_gcopula_diagnose",
    "gaussian_regression_diagnose",
    "negbin_gcopula_diagnose",
    "negbin_regression_diagnose",
    "plot_hist",
    "plot_umap",
    "poisson_gcopula_diagnose",
    "poisson_regression_diagnose",
    "zinb_gcopula_diagnose",
    "zinb_regression_diagnose",
    "zip_regression_diagnose",
]
33
+
34
+
35
+ ###############################################################################
36
+ ## Methods for calculating marginal/gaussian copula AIC/BIC
37
+ ###############################################################################
38
+
39
# Each *_regression_diagnose returns (aic, bic) for the marginal model; each
# *_gcopula_diagnose additionally returns the Gaussian-copula (aic, bic).

# Negative binomial
negbin_regression_diagnose = compose_marginal_diagnose(
    likelihood=est.negbin.negbin_regression_likelihood,
    allowed_keys=["mean", "dispersion"],
)
negbin_gcopula_diagnose = compose_gcopula_diagnose(
    likelihood=est.negbin.negbin_regression_likelihood,
    uniformizer=est.negbin.negbin_uniformizer,
    allowed_keys=["mean", "dispersion"],
)

# Poisson
poisson_regression_diagnose = compose_marginal_diagnose(
    likelihood=est.poisson.poisson_regression_likelihood,
    allowed_keys=["mean"],
)
poisson_gcopula_diagnose = compose_gcopula_diagnose(
    likelihood=est.poisson.poisson_regression_likelihood,
    uniformizer=est.poisson.poisson_uniformizer,
    allowed_keys=["mean"],
)

# Bernoulli
bernoulli_regression_diagnose = compose_marginal_diagnose(
    likelihood=est.bernoulli.bernoulli_regression_likelihood,
    allowed_keys=["mean"],
)
bernoulli_gcopula_diagnose = compose_gcopula_diagnose(
    likelihood=est.bernoulli.bernoulli_regression_likelihood,
    uniformizer=est.bernoulli.bernoulli_uniformizer,
    allowed_keys=["mean"],
)

# Zero-inflated negative binomial
zinb_regression_diagnose = compose_marginal_diagnose(
    likelihood=est.zero_inflated_negbin.zero_inflated_negbin_regression_likelihood,
    allowed_keys=["mean", "dispersion", "zero_inflation"],
)
zinb_gcopula_diagnose = compose_gcopula_diagnose(
    likelihood=est.zero_inflated_negbin.zero_inflated_negbin_regression_likelihood,
    uniformizer=est.zero_inflated_negbin.zero_inflated_negbin_uniformizer,
    allowed_keys=["mean", "dispersion", "zero_inflation"],
)

# Zero-inflated Poisson (marginal only)
zip_regression_diagnose = compose_marginal_diagnose(
    likelihood=est.zero_inflated_poisson.zero_inflated_poisson_regression_likelihood,
    allowed_keys=["mean", "zero_inflation"],
)

# Gaussian
gaussian_regression_diagnose = compose_marginal_diagnose(
    likelihood=est.gaussian.gaussian_regression_likelihood,
    allowed_keys=["mean", "sdev"],
)
gaussian_gcopula_diagnose = compose_gcopula_diagnose(
    likelihood=est.gaussian.gaussian_regression_likelihood,
    uniformizer=est.gaussian.gaussian_uniformizer,
    allowed_keys=["mean", "sdev"],
)
@@ -0,0 +1,119 @@
1
+ from .. import data
2
+ from anndata import AnnData
3
+ from formulaic import model_matrix
4
+ from scipy.stats import norm, multivariate_normal
5
+ from typing import Union
6
+ import numpy as np
7
+ import pandas as pd
8
+ import torch, scipy
9
+
10
+
11
def marginal_aic_bic(likelihood, params: dict, adata: AnnData,
                     formula: Union[str, dict], allowed_keys: set,
                     param_order: list = None, transform: list = None,
                     chunk_size: int = int(1e4), batch_size=512):
    """Compute AIC and BIC of a fitted marginal regression model.

    Args:
        likelihood: Callable ``likelihood(params, x, y)``. Its output is negated
            and summed over batches, so it is presumably a per-batch negative
            log-likelihood — confirm against the estimator.
        params: Fitted parameter tables; flattened by ``likelihood_unwrapper``.
        adata: Data to evaluate on; ``len(adata)`` is the BIC sample size.
        formula: Formula string or dict, standardised over ``allowed_keys``.
        allowed_keys: Parameter names accepted in ``formula``.
        param_order, transform: Forwarded to ``likelihood_unwrapper``.
        chunk_size, batch_size: Forwarded to the data loaders.

    Returns:
        ``(aic, bic)`` as Python floats, where the parameter count is the total
        number of flattened parameter entries.
    """
    device = data.formula.check_device()
    nsample = len(adata)
    params = likelihood_unwrapper(params, param_order, transform).to(device)
    nparam = len(params)

    # Create batches for likelihood calculation; one loader per parameter.
    formula = data.standardize_formula(formula, allowed_keys)
    loaders = data.multiple_formula_loader(
        adata, formula, chunk_size=chunk_size, batch_size=batch_size
    )
    keys = list(loaders.keys())

    # Accumulate in a Python float (float64) rather than a float32 tensor.
    # This is more accurate over many batches, and an empty loader no longer
    # leaves an int without `.cpu()`.
    ll = 0.0
    with torch.no_grad():
        for batches in zip(*loaders.values()):
            x = {key: batch[0].to(device) for key, batch in zip(keys, batches)}
            # Every loader sees the same expression matrix; use the first.
            y = batches[0][1].to(device)
            ll -= float(likelihood(params, x, y))
    aic = 2 * nparam - 2 * ll
    bic = np.log(nsample) * nparam - 2 * ll
    return float(aic), float(bic)
38
+
39
+
40
def gaussian_copula_aic_bic(uniformizer, params: dict, adata: AnnData,
                            formula: Union[str, dict], allowed_keys: set, copula_groups=None):
    """Compute AIC and BIC of the Gaussian-copula dependence layer.

    The data are mapped to uniforms with ``uniformizer`` and then to normal
    scores ``z``. For each covariance group, the copula log-likelihood is the
    multivariate normal log-density of ``z`` minus the independent standard
    normal log-density of ``z``. The parameter count per group is the number
    of non-zero off-diagonal covariance entries divided by two.

    Args:
        uniformizer: Callable ``uniformizer(params, X, y)`` returning values
            indexable by row positions (presumably in (0, 1) — confirm).
        params: Parameter tables including ``'covariance'``.
        adata: Data to evaluate on.
        formula, allowed_keys: Formula specification, standardised over
            ``allowed_keys``.
        copula_groups: Column of ``adata.obs`` with group labels matching the
            covariance keys. When ``None``, every cell is in ``"shared_group"``.

    Returns:
        ``(aic, bic)`` summed over groups. Groups with no cells are skipped.
    """
    params = uniformizer_unwrapper(params)
    covariance = params['covariance']
    # Densify any scipy sparse format, not only CSC.
    y = adata.X
    if scipy.sparse.issparse(y):
        y = y.toarray()
    formula = data.standardize_formula(formula, allowed_keys)
    X = {key: model_matrix(formula[key], pd.DataFrame(adata.obs)) for key in formula}
    if copula_groups is not None:
        memberships = adata.obs[copula_groups]
    else:
        memberships = np.array(["shared_group"] * y.shape[0])

    u = uniformizer(params, X, y)
    aic = 0  # in the future may add group-wise AIC/BIC
    bic = 0
    for g, cov in covariance.items():
        ix = np.where(memberships == g)[0]
        if ix.size == 0:
            # An empty group contributes nothing; log(0) would poison the BIC.
            continue
        nparam = (np.sum(cov != 0) - cov.shape[0]) / 2
        z = norm.ppf(u[ix])
        copula_ll = multivariate_normal.logpdf(z, np.zeros(cov.shape[0]), cov)
        marginal_ll = norm.logpdf(z)
        deviance = -2 * (np.sum(copula_ll) - np.sum(marginal_ll))
        aic += deviance + 2 * nparam
        bic += deviance + np.log(z.shape[0]) * nparam
    return aic, bic
75
+
76
+
77
def compose_marginal_diagnose(likelihood, allowed_keys: set, param_order: list = None, transform: list = None):
    """Bind a marginal likelihood into a ``diagnose(params, adata, formula, ...)`` helper.

    The returned function forwards to ``marginal_aic_bic`` with the bound
    ``likelihood``, ``allowed_keys``, ``param_order`` and ``transform``, and
    returns its ``(aic, bic)`` pair.
    """
    def diagnose(params: dict, adata: AnnData, formula: Union[str, dict],
                 chunk_size: int = int(1e4), batch_size=512):
        return marginal_aic_bic(
            likelihood, params, adata, formula, allowed_keys,
            param_order=param_order,
            transform=transform,
            chunk_size=chunk_size,
            batch_size=batch_size,
        )

    return diagnose
83
+
84
+
85
def compose_gcopula_diagnose(likelihood, uniformizer, allowed_keys: set,
                             param_order: list = None, transform: list = None):
    """Bind a likelihood/uniformizer pair into a Gaussian-copula ``diagnose`` helper.

    The returned function computes the marginal ``(aic, bic)`` with
    ``marginal_aic_bic`` and the copula ``(aic, bic)`` with
    ``gaussian_copula_aic_bic``. It returns all four as
    ``(marginal_aic, marginal_bic, copula_aic, copula_bic)``.
    """
    # `formula` is annotated Union[str, dict] to match marginal_aic_bic and
    # gaussian_copula_aic_bic, which both accept either form (was `str`).
    def diagnose(params: dict, adata: AnnData, formula: Union[str, dict], copula_groups=None,
                 chunk_size: int = int(1e4), batch_size=512):
        marginal_aic, marginal_bic = marginal_aic_bic(
            likelihood, params, adata, formula, allowed_keys,
            param_order, transform, chunk_size, batch_size,
        )
        copula_aic, copula_bic = gaussian_copula_aic_bic(
            uniformizer, params, adata, formula, allowed_keys, copula_groups
        )
        return marginal_aic, marginal_bic, copula_aic, copula_bic

    return diagnose
94
+
95
+
96
+ ###############################################################################
97
+ ## Helper for converting params to match likelihood/uniformizer input format
98
+ ###############################################################################
99
+
100
def likelihood_unwrapper(params: dict, param_order: list = None, transform: list = None):
    """Flatten 2-D parameter tables into a single 1-D float tensor.

    Args:
        params: Mapping of parameter name to a table exposing ``.values`` and a
            2-D ``.shape`` (e.g. a DataFrame).
        param_order: Keys to concatenate, in order. Defaults to every key except
            ``"covariance"``, in dict order.
        transform: Optional callables; ``transform[i]`` is applied to the
            flattened i-th table.

    Returns:
        ``torch.cat`` of the row-major flattened (and optionally transformed) tables.
    """
    if param_order is None:
        keys = [name for name in params if name != "covariance"]
    else:
        keys = param_order

    pieces = []
    for position, name in enumerate(keys):
        table = params[name]
        flat = torch.Tensor(table.values).reshape(table.shape[0] * table.shape[1])
        pieces.append(flat if transform is None else transform[position](flat))

    return torch.cat(pieces, dim=0)
112
+
113
def uniformizer_unwrapper(params):
    """Convert parameter tables to arrays in the layout uniformizers expect.

    Every non-covariance entry is turned into a numpy array. ``'covariance'``
    is always returned as a dict of group name to array: a single table is
    wrapped as ``{'shared_group': ...}``. Inputs that are already numpy arrays
    are passed through; other entries must expose ``.values`` (e.g.
    DataFrames). The input mapping is not modified.
    """
    def as_array(value):
        return value if isinstance(value, np.ndarray) else value.values

    # (The original had a duplicated `params = params = ...` assignment.)
    unwrapped = {
        key: value if key == 'covariance' else as_array(value)
        for key, value in params.items()
    }
    covariance = unwrapped['covariance']
    if isinstance(covariance, dict):
        unwrapped['covariance'] = {group: as_array(cov) for group, cov in covariance.items()}
    else:
        unwrapped['covariance'] = {'shared_group': as_array(covariance)}
    return unwrapped
+ return params