isovae-0.1.0.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
isovae-0.1.0/LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2026 IsoVAE developers
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
isovae-0.1.0/PKG-INFO ADDED
@@ -0,0 +1,125 @@
+ Metadata-Version: 2.4
+ Name: isovae
+ Version: 0.1.0
+ Summary: IsoVAE: isoform-usage prediction and long-read isoform-usage denoising for single-cell RNA-seq
+ Author: IsoVAE developers
+ License-Expression: MIT
+ Project-URL: Homepage, https://github.com/your-username/IsoVAE
+ Project-URL: Documentation, https://your-username.github.io/IsoVAE/
+ Project-URL: Repository, https://github.com/your-username/IsoVAE
+ Project-URL: Issues, https://github.com/your-username/IsoVAE/issues
+ Keywords: single-cell RNA-seq,isoform usage,long-read RNA-seq,variational autoencoder,denoising
+ Requires-Python: >=3.10
+ Description-Content-Type: text/markdown
+ License-File: LICENSE
+ Requires-Dist: numpy>=1.23
+ Requires-Dist: pandas>=1.5
+ Requires-Dist: scipy>=1.9
+ Requires-Dist: scikit-learn>=1.2
+ Requires-Dist: anndata>=0.9
+ Requires-Dist: torch>=2.0
+ Requires-Dist: matplotlib>=3.6
+ Requires-Dist: seaborn>=0.12
+ Provides-Extra: docs
+ Requires-Dist: mkdocs>=1.5; extra == "docs"
+ Requires-Dist: mkdocs-material>=9.5; extra == "docs"
+ Requires-Dist: mkdocstrings[python]>=0.24; extra == "docs"
+ Provides-Extra: dev
+ Requires-Dist: pytest; extra == "dev"
+ Requires-Dist: ruff; extra == "dev"
+ Dynamic: license-file
+
+ # IsoVAE
+
+ IsoVAE is a Python package for single-cell isoform-usage analysis. It supports:
+
+ 1. **Isoform-usage prediction** from short-read single-cell gene-expression profiles.
+ 2. **Long-read isoform-usage denoising** from sparse long-read isoform count matrices.
+
+ IsoVAE models **within-gene isoform usage proportions**, not absolute transcript abundance.
+
+ ## Installation
+
+ ```bash
+ pip install isovae
+ ```
+
+ For local development:
+
+ ```bash
+ git clone https://github.com/your-username/IsoVAE.git
+ cd IsoVAE
+ pip install -e .
+ ```
+
+ ## Quick start
+
+ ```python
+ import scanpy as sc
+ from isovae import (
+     load_artifact,
+     reconstruct_preprocessor_from_training_data,
+     predict_isoform_usage,
+     denoise_isoform_usage,
+ )
+
+ model_path = "path/to/vae_xda_model.pt"
+
+ gene_train = sc.read("path/to/training_gene_matrix.h5ad")
+ iso_train = sc.read("path/to/training_isoform_matrix.h5ad")
+
+ preprocessor = reconstruct_preprocessor_from_training_data(
+     model_path,
+     adata_gene_train=gene_train,
+     adata_iso_train=iso_train,
+     seed=42,
+ )
+
+ artifact = load_artifact(model_path, preprocessor=preprocessor, device="cpu")
+
+ # Predict isoform usage from short-read data.
+ gene_query = sc.read("path/to/query_gene_matrix.h5ad")
+ pred_usage, pred_meta = predict_isoform_usage(artifact, gene_query)
+ pred_usage.to_csv("predicted_isoform_usage.csv")
+
+ # Denoise long-read isoform usage.
+ iso_query = sc.read("path/to/query_isoform_matrix.h5ad")
+ denoised_usage, noisy_usage, denoise_meta = denoise_isoform_usage(artifact, iso_query)
+ denoised_usage.to_csv("denoised_isoform_usage.csv")
+ ```
+
+ ## Documentation
+
+ The documentation source is in `docs/` and can be built with MkDocs:
+
+ ```bash
+ pip install -e ".[docs]"
+ mkdocs serve
+ ```
+
+ To deploy to GitHub Pages:
+
+ ```bash
+ mkdocs gh-deploy
+ ```
+
+ See `docs/deployment.md` for deployment instructions for GitHub Pages, Read the Docs, Netlify and Vercel.
+
+ ## Repository layout
+
+ ```text
+ .
+ ├── src/isovae/        # Python package
+ ├── docs/              # Documentation source
+ ├── mkdocs.yml         # Documentation configuration
+ ├── pyproject.toml     # Package metadata
+ ├── requirements.txt
+ ├── LICENSE
+ └── README.md
+ ```
+
+ Large data files, AnnData objects, model checkpoints and manuscript outputs are not included in the package.
+
+ ## Citation
+
+ If you use IsoVAE, please cite the accompanying manuscript after publication.
isovae-0.1.0/README.md ADDED
@@ -0,0 +1,94 @@
+ # IsoVAE
+
+ IsoVAE is a Python package for single-cell isoform-usage analysis. It supports:
+
+ 1. **Isoform-usage prediction** from short-read single-cell gene-expression profiles.
+ 2. **Long-read isoform-usage denoising** from sparse long-read isoform count matrices.
+
+ IsoVAE models **within-gene isoform usage proportions**, not absolute transcript abundance.
+
+ ## Installation
+
+ ```bash
+ pip install isovae
+ ```
+
+ For local development:
+
+ ```bash
+ git clone https://github.com/your-username/IsoVAE.git
+ cd IsoVAE
+ pip install -e .
+ ```
+
+ ## Quick start
+
+ ```python
+ import scanpy as sc
+ from isovae import (
+     load_artifact,
+     reconstruct_preprocessor_from_training_data,
+     predict_isoform_usage,
+     denoise_isoform_usage,
+ )
+
+ model_path = "path/to/vae_xda_model.pt"
+
+ gene_train = sc.read("path/to/training_gene_matrix.h5ad")
+ iso_train = sc.read("path/to/training_isoform_matrix.h5ad")
+
+ preprocessor = reconstruct_preprocessor_from_training_data(
+     model_path,
+     adata_gene_train=gene_train,
+     adata_iso_train=iso_train,
+     seed=42,
+ )
+
+ artifact = load_artifact(model_path, preprocessor=preprocessor, device="cpu")
+
+ # Predict isoform usage from short-read data.
+ gene_query = sc.read("path/to/query_gene_matrix.h5ad")
+ pred_usage, pred_meta = predict_isoform_usage(artifact, gene_query)
+ pred_usage.to_csv("predicted_isoform_usage.csv")
+
+ # Denoise long-read isoform usage.
+ iso_query = sc.read("path/to/query_isoform_matrix.h5ad")
+ denoised_usage, noisy_usage, denoise_meta = denoise_isoform_usage(artifact, iso_query)
+ denoised_usage.to_csv("denoised_isoform_usage.csv")
+ ```
+
+ ## Documentation
+
+ The documentation source is in `docs/` and can be built with MkDocs:
+
+ ```bash
+ pip install -e ".[docs]"
+ mkdocs serve
+ ```
+
+ To deploy to GitHub Pages:
+
+ ```bash
+ mkdocs gh-deploy
+ ```
+
+ See `docs/deployment.md` for deployment instructions for GitHub Pages, Read the Docs, Netlify and Vercel.
+
+ ## Repository layout
+
+ ```text
+ .
+ ├── src/isovae/        # Python package
+ ├── docs/              # Documentation source
+ ├── mkdocs.yml         # Documentation configuration
+ ├── pyproject.toml     # Package metadata
+ ├── requirements.txt
+ ├── LICENSE
+ └── README.md
+ ```
+
+ Large data files, AnnData objects, model checkpoints and manuscript outputs are not included in the package.
+
+ ## Citation
+
+ If you use IsoVAE, please cite the accompanying manuscript after publication.
isovae-0.1.0/pyproject.toml ADDED
@@ -0,0 +1,54 @@
+ [build-system]
+ requires = ["setuptools>=68", "wheel"]
+ build-backend = "setuptools.build_meta"
+
+ [project]
+ name = "isovae"
+ version = "0.1.0"
+ description = "IsoVAE: isoform-usage prediction and long-read isoform-usage denoising for single-cell RNA-seq"
+ readme = "README.md"
+ requires-python = ">=3.10"
+ license = "MIT"
+ authors = [
+     {name = "IsoVAE developers"}
+ ]
+ keywords = [
+     "single-cell RNA-seq",
+     "isoform usage",
+     "long-read RNA-seq",
+     "variational autoencoder",
+     "denoising"
+ ]
+ dependencies = [
+     "numpy>=1.23",
+     "pandas>=1.5",
+     "scipy>=1.9",
+     "scikit-learn>=1.2",
+     "anndata>=0.9",
+     "torch>=2.0",
+     "matplotlib>=3.6",
+     "seaborn>=0.12"
+ ]
+
+ [project.optional-dependencies]
+ docs = [
+     "mkdocs>=1.5",
+     "mkdocs-material>=9.5",
+     "mkdocstrings[python]>=0.24"
+ ]
+ dev = [
+     "pytest",
+     "ruff"
+ ]
+
+ [tool.setuptools.packages.find]
+ where = ["src"]
+
+ [tool.ruff]
+ line-length = 100
+
+ [project.urls]
+ Homepage = "https://github.com/your-username/IsoVAE"
+ Documentation = "https://your-username.github.io/IsoVAE/"
+ Repository = "https://github.com/your-username/IsoVAE"
+ Issues = "https://github.com/your-username/IsoVAE/issues"
isovae-0.1.0/setup.cfg ADDED
@@ -0,0 +1,4 @@
+ [egg_info]
+ tag_build =
+ tag_date = 0
+
isovae-0.1.0/src/isovae/__init__.py ADDED
@@ -0,0 +1,39 @@
+ """IsoVAE public API.
+
+ IsoVAE predicts gene-wise isoform usage from short-read single-cell gene
+ expression and denoises sparse long-read isoform-usage measurements.
+ """
+
+ from .data import (
+     IsoVAEPreprocessor,
+     align_paired_cells,
+     counts_to_gene_usage,
+     prepare_paired_data,
+ )
+ from .inference import (
+     IsoVAEArtifact,
+     denoise_isoform_usage,
+     load_artifact,
+     predict_isoform_usage,
+     reconstruct_preprocessor_from_training_data,
+ )
+ from .model import IsoVAEConfig, IsoVAEModel
+ from .utils import select_device, set_seed
+
+ __all__ = [
+     "IsoVAEArtifact",
+     "IsoVAEConfig",
+     "IsoVAEModel",
+     "IsoVAEPreprocessor",
+     "align_paired_cells",
+     "counts_to_gene_usage",
+     "denoise_isoform_usage",
+     "load_artifact",
+     "predict_isoform_usage",
+     "prepare_paired_data",
+     "reconstruct_preprocessor_from_training_data",
+     "select_device",
+     "set_seed",
+ ]
+
+ __version__ = "0.1.0"
isovae-0.1.0/src/isovae/data.py ADDED
@@ -0,0 +1,268 @@
+ from __future__ import annotations
+
+ from dataclasses import dataclass
+ from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple
+
+ import anndata as ad
+ import numpy as np
+ import pandas as pd
+ import scipy.sparse as sp
+ from sklearn.model_selection import train_test_split
+ from sklearn.preprocessing import StandardScaler
+
+ from .utils import set_seed
+
+ Array = np.ndarray
+
+
+ def _as_array(x: Any) -> Array:
+     if sp.issparse(x):
+         return x.toarray()
+     return np.asarray(x)
+
+
+ def strip_barcode_suffix(index: pd.Index) -> pd.Index:
+     """Normalize 10x-style cell barcodes by removing a trailing ``-1`` suffix."""
+     return index.astype(str).str.replace(r"-1$", "", regex=True)
+
+
+ def align_paired_cells(
+     adata_gene: ad.AnnData,
+     adata_iso: ad.AnnData,
+     strict: bool = True,
+ ) -> Tuple[ad.AnnData, ad.AnnData]:
+     """Align paired short-read and long-read AnnData objects by cell barcode."""
+     gene_barcodes = strip_barcode_suffix(adata_gene.obs_names)
+     iso_barcodes = strip_barcode_suffix(adata_iso.obs_names)
+
+     gene_pos = {barcode: i for i, barcode in enumerate(gene_barcodes)}
+     iso_pos = {barcode: i for i, barcode in enumerate(iso_barcodes)}
+     common = sorted(set(gene_pos) & set(iso_pos))
+     if strict and not common:
+         raise ValueError("No common cells found after barcode normalization.")
+
+     gene_idx = [gene_pos[barcode] for barcode in common]
+     iso_idx = [iso_pos[barcode] for barcode in common]
+     gene_aligned = adata_gene[gene_idx, :].copy()
+     iso_aligned = adata_iso[iso_idx, :].copy()
+     normalized_index = pd.Index(common, name="barcode")
+     gene_aligned.obs["barcode_raw"] = gene_aligned.obs_names.astype(str)
+     iso_aligned.obs["barcode_raw"] = iso_aligned.obs_names.astype(str)
+     gene_aligned.obs_names = normalized_index
+     iso_aligned.obs_names = normalized_index
+     return gene_aligned, iso_aligned
+
+
+ def make_unique_gene_symbol_view(adata_gene: ad.AnnData) -> ad.AnnData:
+     """Return a copy with unique gene-symbol columns."""
+     if "gene_symbol" in adata_gene.var.columns:
+         symbols = adata_gene.var["gene_symbol"].astype(str).values
+     else:
+         symbols = adata_gene.var_names.astype(str).values
+     adata_gene = adata_gene.copy()
+     adata_gene.var["gene_symbol"] = symbols
+     keep = ~pd.Index(symbols).duplicated(keep="first")
+     return adata_gene[:, keep].copy()
+
+
+ def _col_var(x: Any) -> Array:
+     if sp.issparse(x):
+         mean = np.asarray(x.mean(axis=0)).ravel()
+         sq_mean = np.asarray(x.power(2).mean(axis=0)).ravel()
+         return sq_mean - mean**2
+     return np.asarray(x).var(axis=0)
+
+
+ def normalize_gene_counts(x: Any) -> Array:
+     """Library-size normalize selected short-read gene counts and apply log1p."""
+     if sp.issparse(x):
+         lib = np.asarray(x.sum(axis=1)).ravel().astype(np.float32)
+         scale = (1e4 / np.maximum(lib, 1e-6)).astype(np.float32)
+         return x.multiply(scale[:, None]).log1p().toarray().astype(np.float32)
+
+     x = np.asarray(x, dtype=np.float32)
+     lib = x.sum(axis=1, keepdims=True)
+     return np.log1p((x / np.maximum(lib, 1e-6)) * 1e4).astype(np.float32)
+
+
+ def counts_to_gene_usage(
+     y_counts: Any,
+     isoform_gene: Sequence[str],
+ ) -> Tuple[Array, List[Array], Array]:
+     """Convert transcript counts to within-gene isoform-usage proportions."""
+     y_counts = _as_array(y_counts).astype(np.float32)
+     gene_to_idx: Dict[str, List[int]] = {}
+     for j, g in enumerate(np.asarray(isoform_gene).astype(str)):
+         gene_to_idx.setdefault(g, []).append(j)
+
+     genes = np.array(list(gene_to_idx.keys()), dtype=object)
+     groups = [np.array(v, dtype=np.int64) for v in gene_to_idx.values()]
+     y_usage = np.zeros_like(y_counts, dtype=np.float32)
+     for idx in groups:
+         denom = y_counts[:, idx].sum(axis=1, keepdims=True)
+         rows = np.where(denom[:, 0] > 0)[0]
+         if rows.size:
+             y_usage[np.ix_(rows, idx)] = y_counts[np.ix_(rows, idx)] / denom[rows]
+     return y_usage, groups, genes
+
+
+ @dataclass
+ class IsoVAEPreprocessor:
+     """Feature names and scaler needed for short-read-only prediction."""
+
+     genes: List[str]
+     isoforms: List[str]
+     isoform_gene: List[str]
+     gene_groups: List[Array]
+     scaler: StandardScaler
+
+     def transform_gene_adata(self, adata_gene: ad.AnnData) -> Tuple[Array, int]:
+         """Extract and scale the model input genes from a short-read AnnData object."""
+         adata_gene = make_unique_gene_symbol_view(adata_gene)
+         symbols = adata_gene.var["gene_symbol"].astype(str).values
+         symbol_to_pos = {g: i for i, g in enumerate(symbols)}
+         cols = []
+         n_found = 0
+         for gene in self.genes:
+             if gene in symbol_to_pos:
+                 cols.append(adata_gene.X[:, symbol_to_pos[gene]])
+                 n_found += 1
+             else:
+                 cols.append(sp.csr_matrix((adata_gene.n_obs, 1), dtype=np.float32))
+         x_raw = sp.hstack(cols, format="csr") if sp.issparse(cols[0]) else np.column_stack(cols)
+         x = normalize_gene_counts(x_raw)
+         return self.scaler.transform(x).astype(np.float32), n_found
+
+     def extract_iso_counts(self, adata_iso: ad.AnnData) -> Tuple[Array, int]:
+         """Extract model isoform-count features from a long-read AnnData object."""
+         transcript = (
+             adata_iso.var["transcript_id"].astype(str).values
+             if "transcript_id" in adata_iso.var.columns
+             else adata_iso.var_names.astype(str).values
+         )
+         tx_to_pos = {t: i for i, t in enumerate(transcript)}
+         cols = []
+         n_found = 0
+         for tx in self.isoforms:
+             if tx in tx_to_pos:
+                 cols.append(adata_iso.X[:, tx_to_pos[tx]])
+                 n_found += 1
+             else:
+                 cols.append(sp.csr_matrix((adata_iso.n_obs, 1), dtype=np.float32))
+         y = sp.hstack(cols, format="csr") if sp.issparse(cols[0]) else np.column_stack(cols)
+         return _as_array(y).astype(np.float32), n_found
+
+
+ def prepare_paired_data(
+     adata_gene: ad.AnnData,
+     adata_iso: ad.AnnData,
+     gene_hvg: int = 1200,
+     min_iso_cells: int = 10,
+     test_size: float = 0.20,
+     val_size_within_train: float = 0.20,
+     seed: int = 42,
+ ) -> Dict[str, Any]:
+     """Prepare paired data for training/evaluation.
+
+     Feature selection and scaler fitting are performed using training cells only.
+     """
+     set_seed(seed)
+     adata_gene, adata_iso = align_paired_cells(adata_gene, adata_iso)
+     adata_gene = make_unique_gene_symbol_view(adata_gene)
+
+     all_idx = np.arange(adata_gene.n_obs)
+     train_idx, test_idx = train_test_split(all_idx, test_size=test_size, random_state=seed)
+     train_idx, val_idx = train_test_split(
+         train_idx, test_size=val_size_within_train, random_state=seed
+     )
+
+     gene_symbols = adata_gene.var["gene_symbol"].astype(str).values
+     iso_gene_all = adata_iso.var["gene_id"].astype(str).values
+     common_genes = np.intersect1d(np.unique(gene_symbols), np.unique(iso_gene_all))
+     if common_genes.size == 0:
+         raise ValueError("No overlap between short-read gene_symbol and LR isoform gene_id.")
+
+     common_mask = np.isin(gene_symbols, common_genes)
+     adata_gene_common = adata_gene[:, common_mask]
+     var = _col_var(adata_gene_common.X[train_idx, :])
+     top_idx = np.argsort(var)[::-1][: min(gene_hvg, adata_gene_common.n_vars)]
+     selected_genes = adata_gene_common.var["gene_symbol"].astype(str).values[top_idx]
+
+     iso_mask = np.isin(iso_gene_all, selected_genes)
+     adata_iso_sel = adata_iso[:, iso_mask].copy()
+     iso_nonzero_cells = np.asarray((adata_iso_sel.X[train_idx, :] > 0).sum(axis=0)).ravel()
+     adata_iso_sel = adata_iso_sel[:, iso_nonzero_cells >= min_iso_cells].copy()
+
+     iso_gene = adata_iso_sel.var["gene_id"].astype(str).values
+     gene_to_idx: Dict[str, List[int]] = {}
+     for j, g in enumerate(iso_gene):
+         gene_to_idx.setdefault(g, []).append(j)
+     genes_keep = [g for g, idx in gene_to_idx.items() if len(idx) >= 2]
+     if not genes_keep:
+         raise ValueError("No genes with >=2 isoforms after filtering.")
+
+     iso_keep_idx = np.concatenate([gene_to_idx[g] for g in genes_keep]).astype(np.int64)
+     adata_iso_sel = adata_iso_sel[:, iso_keep_idx].copy()
+     isoform_gene = adata_iso_sel.var["gene_id"].astype(str).values
+     isoforms = (
+         adata_iso_sel.var["transcript_id"].astype(str).values
+         if "transcript_id" in adata_iso_sel.var.columns
+         else adata_iso_sel.var_names.astype(str).values
+     )
+     y_counts = _as_array(adata_iso_sel.X).astype(np.float32)
+     y_usage, gene_groups, genes = counts_to_gene_usage(y_counts, isoform_gene)
+     genes = genes.astype(str)
+
+     # Build short-read inputs in retained gene-group order.
+     # Manual extraction before scaler is fitted.
+     adata_gene_u = make_unique_gene_symbol_view(adata_gene)
+     symbols = adata_gene_u.var["gene_symbol"].astype(str).values
+     symbol_to_pos = {g: i for i, g in enumerate(symbols)}
+     cols = []
+     n_gene_found = 0
+     for g in genes:
+         if g in symbol_to_pos:
+             cols.append(adata_gene_u.X[:, symbol_to_pos[g]])
+             n_gene_found += 1
+         else:
+             cols.append(sp.csr_matrix((adata_gene_u.n_obs, 1), dtype=np.float32))
+     x_raw = sp.hstack(cols, format="csr") if sp.issparse(cols[0]) else np.column_stack(cols)
+     x_norm = normalize_gene_counts(x_raw)
+     scaler = StandardScaler().fit(x_norm[train_idx])
+     x_all = scaler.transform(x_norm).astype(np.float32)
+
+     preprocessor = IsoVAEPreprocessor(
+         genes=list(map(str, genes)),
+         isoforms=list(map(str, isoforms)),
+         isoform_gene=list(map(str, isoform_gene)),
+         gene_groups=gene_groups,
+         scaler=scaler,
+     )
+     return {
+         "x_train": x_all[train_idx],
+         "x_val": x_all[val_idx],
+         "x_test": x_all[test_idx],
+         "y_train_usage": y_usage[train_idx],
+         "y_val_usage": y_usage[val_idx],
+         "y_test_usage": y_usage[test_idx],
+         "y_train_counts": y_counts[train_idx],
+         "y_val_counts": y_counts[val_idx],
+         "y_test_counts": y_counts[test_idx],
+         "train_obs_names": adata_gene.obs_names[train_idx].astype(str).tolist(),
+         "val_obs_names": adata_gene.obs_names[val_idx].astype(str).tolist(),
+         "test_obs_names": adata_gene.obs_names[test_idx].astype(str).tolist(),
+         "preprocessor": preprocessor,
+         "gene_groups": gene_groups,
+         "genes_final": genes.astype(str),
+         "isoforms_final": np.asarray(isoforms).astype(str),
+         "isoform_gene": np.asarray(isoform_gene).astype(str),
+         "meta": {
+             "n_cells": int(adata_gene.n_obs),
+             "n_train": int(len(train_idx)),
+             "n_val": int(len(val_idx)),
+             "n_test": int(len(test_idx)),
+             "n_gene_groups": int(len(gene_groups)),
+             "n_isoforms_output": int(len(isoforms)),
+             "n_gene_features_found": int(n_gene_found),
+         },
+     }
@@ -0,0 +1,269 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+ from pathlib import Path
5
+ from typing import Dict, List, Optional, Sequence, Tuple
6
+
7
+ import anndata as ad
8
+ import numpy as np
9
+ import pandas as pd
10
+ import scipy.sparse as sp
11
+ import torch
12
+ from sklearn.model_selection import train_test_split
13
+ from sklearn.preprocessing import StandardScaler
14
+
15
+ from .data import (
16
+ IsoVAEPreprocessor,
17
+ align_paired_cells,
18
+ counts_to_gene_usage,
19
+ make_unique_gene_symbol_view,
20
+ normalize_gene_counts,
21
+ )
22
+ from .model import (
23
+ IsoVAEModel,
24
+ iso_encoder_input_from_counts,
25
+ logits_to_usage,
26
+ make_config_from_checkpoint,
27
+ )
28
+ from .utils import select_device
29
+
30
+ Array = np.ndarray
31
+
32
+
33
+ @dataclass
34
+ class IsoVAEArtifact:
35
+ """Loaded model plus feature metadata.
36
+
37
+ A fitted ``preprocessor`` is required for short-read-only prediction.
38
+ Older checkpoints contain feature names but not the scaler; in that case,
39
+ pass a preprocessor reconstructed from the original training data.
40
+ """
41
+
42
+ model: IsoVAEModel
43
+ genes: List[str]
44
+ isoforms: List[str]
45
+ isoform_gene: List[str]
46
+ gene_groups: List[Array]
47
+ preprocessor: Optional[IsoVAEPreprocessor] = None
48
+ device: str = "cpu"
49
+
50
+ def to(self, device: str) -> "IsoVAEArtifact":
51
+ self.device = device
52
+ self.model.to(device)
53
+ self.model.eval()
54
+ return self
55
+
56
+
57
+ def load_artifact(
58
+ checkpoint: str | Path,
59
+ preprocessor: Optional[IsoVAEPreprocessor] = None,
60
+ device: Optional[str] = None,
61
+ ) -> IsoVAEArtifact:
62
+ """Load an IsoVAE checkpoint saved by the training scripts."""
63
+ device = device or select_device(prefer_cuda=True)
64
+ ckpt = torch.load(checkpoint, map_location="cpu", weights_only=False)
65
+ genes = list(map(str, ckpt["genes_final"]))
66
+ isoforms = list(map(str, ckpt["isoforms_final"]))
67
+ isoform_gene = list(map(str, ckpt["isoform_gene"]))
68
+ groups = [np.asarray(g, dtype=np.int64) for g in ckpt["gene_groups"]]
69
+ config = make_config_from_checkpoint(ckpt.get("model_config", {}), ckpt.get("model_state"))
70
+ model = IsoVAEModel(
71
+ n_gene_inputs=len(genes),
72
+ n_isoforms=len(isoforms),
73
+ n_gene_groups=len(groups),
74
+ config=config,
75
+ )
76
+ model.load_state_dict(ckpt["model_state"], strict=True)
77
+ model.to(device)
78
+ model.eval()
79
+
80
+ if preprocessor is not None:
81
+ # Ensure model and preprocessor features agree.
82
+ if list(preprocessor.genes) != genes:
83
+ raise ValueError("Preprocessor genes do not match checkpoint genes.")
84
+ if list(preprocessor.isoforms) != isoforms:
85
+ raise ValueError("Preprocessor isoforms do not match checkpoint isoforms.")
86
+
87
+ return IsoVAEArtifact(
88
+ model=model,
89
+ genes=genes,
90
+ isoforms=isoforms,
91
+ isoform_gene=isoform_gene,
92
+ gene_groups=groups,
93
+ preprocessor=preprocessor,
94
+ device=device,
95
+ )
96
+
97
+
98
+ def reconstruct_preprocessor_from_training_data(
99
+ checkpoint: str | Path,
100
+ adata_gene_train: ad.AnnData,
101
+ adata_iso_train: Optional[ad.AnnData] = None,
102
+ seed: int = 42,
103
+ test_size: float = 0.20,
104
+ val_size_within_train: float = 0.20,
105
+ ) -> IsoVAEPreprocessor:
106
+ """Reconstruct preprocessing metadata for older checkpoints.
107
+
108
+ The final manuscript checkpoint stores feature names and model weights, but
109
+ not the fitted ``StandardScaler``. This helper rebuilds the scaler from the
110
+ original paired training data using the same deterministic split. If
111
+ ``adata_iso_train`` is provided, cells are first aligned exactly as in
112
+ training.
113
+ """
114
+ ckpt = torch.load(checkpoint, map_location="cpu", weights_only=False)
115
+ genes = list(map(str, ckpt["genes_final"]))
116
+ isoforms = list(map(str, ckpt["isoforms_final"]))
117
+ isoform_gene = list(map(str, ckpt["isoform_gene"]))
118
+ groups = [np.asarray(g, dtype=np.int64) for g in ckpt["gene_groups"]]
119
+
120
+ if adata_iso_train is not None:
121
+ adata_gene_train, _ = align_paired_cells(adata_gene_train, adata_iso_train)
122
+ adata_gene_train = make_unique_gene_symbol_view(adata_gene_train)
123
+ symbols = adata_gene_train.var["gene_symbol"].astype(str).values
124
+ symbol_to_pos = {g: i for i, g in enumerate(symbols)}
125
+ cols = []
126
+ for gene in genes:
127
+ if gene in symbol_to_pos:
128
+ cols.append(adata_gene_train.X[:, symbol_to_pos[gene]])
129
+ else:
130
+ cols.append(sp.csr_matrix((adata_gene_train.n_obs, 1), dtype=np.float32))
131
+ x_raw = sp.hstack(cols, format="csr") if sp.issparse(cols[0]) else np.column_stack(cols)
132
+ x_norm = normalize_gene_counts(x_raw)
133
+
134
+ all_idx = np.arange(adata_gene_train.n_obs)
135
+ train_idx, _ = train_test_split(all_idx, test_size=test_size, random_state=seed)
136
+ train_idx, _ = train_test_split(
137
+ train_idx, test_size=val_size_within_train, random_state=seed
138
+ )
139
+ scaler = StandardScaler().fit(x_norm[train_idx])
140
+ return IsoVAEPreprocessor(
141
+ genes=genes,
142
+ isoforms=isoforms,
143
+ isoform_gene=isoform_gene,
144
+ gene_groups=groups,
145
+ scaler=scaler,
146
+ )
147
+
148
+
149
+ @torch.no_grad()
150
+ def predict_usage_array(
151
+ model: IsoVAEModel,
152
+ x: Array,
153
+ groups: Sequence[Array],
154
+ device: str = "cpu",
155
+ batch_size: int = 512,
156
+ ) -> Array:
157
+ """Predict isoform usage from a scaled gene-expression matrix."""
158
+ model.eval()
159
+ preds: List[Array] = []
160
+ for start in range(0, x.shape[0], batch_size):
161
+ xb = torch.from_numpy(x[start : start + batch_size].astype(np.float32)).to(device)
162
+ logits, _, _ = model.predict_from_gene(xb)
163
+ preds.append(logits_to_usage(logits, groups))
164
+ return np.vstack(preds)
165
+
166
+
167
+ @torch.no_grad()
168
+ def denoise_usage_array(
169
+ model: IsoVAEModel,
170
+ y_counts: Array,
171
+ groups: Sequence[Array],
172
+ device: str = "cpu",
173
+ keep_rate: Optional[float] = None,
174
+ batch_size: int = 512,
175
+ ) -> Tuple[Array, Array]:
176
+ """Denoise long-read isoform counts into within-gene isoform usage.
177
+
178
+ If ``keep_rate`` is provided, counts are first binomially downsampled. This
179
+ is useful for simulated denoising experiments.
180
+ """
181
+ model.eval()
182
+ group_tensors = [torch.as_tensor(g, dtype=torch.long, device=device) for g in groups]
183
+ preds: List[Array] = []
184
+ noisy_counts: List[Array] = []
185
+ for start in range(0, y_counts.shape[0], batch_size):
186
+ yc = torch.from_numpy(y_counts[start : start + batch_size].astype(np.float32)).to(device)
187
+ if keep_rate is None:
188
+ yn = yc
189
+ else:
190
+ yn = torch.binomial(torch.clamp(yc, min=0.0), torch.full_like(yc, float(keep_rate)))
191
+ iso_in = iso_encoder_input_from_counts(yn, group_tensors)
192
+ logits, _, _ = model.denoise_from_iso(iso_in)
193
+ preds.append(logits_to_usage(logits, groups))
194
+ noisy_counts.append(yn.detach().cpu().numpy().astype(np.float32))
195
+ return np.vstack(preds), np.vstack(noisy_counts)
196
+
197
+
198
+ def predict_isoform_usage(
199
+ artifact: IsoVAEArtifact,
200
+ adata_gene: ad.AnnData,
201
+ batch_size: int = 512,
202
+ ) -> Tuple[pd.DataFrame, Dict[str, int]]:
203
+ """Predict isoform usage from short-read-only AnnData."""
204
+ if artifact.preprocessor is None:
205
+ raise ValueError(
206
+ "A fitted IsoVAEPreprocessor is required for AnnData prediction. "
207
+ "Reconstruct it from the paired training data or load one saved with your model."
208
+ )
209
+ x, n_found = artifact.preprocessor.transform_gene_adata(adata_gene)
210
+ pred = predict_usage_array(
211
+ artifact.model, x, artifact.gene_groups, device=artifact.device, batch_size=batch_size
212
+ )
213
+ df = pd.DataFrame(pred, index=adata_gene.obs_names.astype(str), columns=artifact.isoforms)
214
+ return df, {"n_gene_features_found": int(n_found), "n_gene_features_total": len(artifact.genes)}
215
+
216
+
217
+ def denoise_isoform_usage(
+     artifact: IsoVAEArtifact,
+     adata_iso: ad.AnnData,
+     keep_rate: Optional[float] = None,
+     batch_size: int = 512,
+ ) -> Tuple[pd.DataFrame, pd.DataFrame, Dict[str, int]]:
+     """Denoise long-read isoform counts from AnnData.
+
+     Returns denoised usage, direct observed/noisy usage, and a metadata dict.
+     """
+     if artifact.preprocessor is not None:
+         counts, n_found = artifact.preprocessor.extract_iso_counts(adata_iso)
+     else:
+         # Feature extraction can still be done from checkpoint metadata.
+         pre = IsoVAEPreprocessor(
+             genes=artifact.genes,
+             isoforms=artifact.isoforms,
+             isoform_gene=artifact.isoform_gene,
+             gene_groups=artifact.gene_groups,
+             scaler=None,  # type: ignore[arg-type]
+         )
+         counts, n_found = pre.extract_iso_counts(adata_iso)
+
+     denoised, noisy_counts = denoise_usage_array(
+         artifact.model,
+         counts,
+         artifact.gene_groups,
+         device=artifact.device,
+         keep_rate=keep_rate,
+         batch_size=batch_size,
+     )
+     observed_usage, _, _ = counts_to_gene_usage(noisy_counts, artifact.isoform_gene)
+     den_df = pd.DataFrame(denoised, index=adata_iso.obs_names.astype(str), columns=artifact.isoforms)
+     obs_df = pd.DataFrame(observed_usage, index=adata_iso.obs_names.astype(str), columns=artifact.isoforms)
+     return den_df, obs_df, {"n_isoforms_found": int(n_found), "n_isoforms_total": len(artifact.isoforms)}
+
+
+ def usage_long_table(
+     usage: pd.DataFrame,
+     isoform_gene: Sequence[str],
+     obs_columns: Optional[pd.DataFrame] = None,
+ ) -> pd.DataFrame:
+     """Convert a cell-by-isoform usage matrix to a long table for plotting."""
+     long = usage.reset_index(names="cell").melt(
+         id_vars="cell", var_name="isoform", value_name="usage"
+     )
+     gene_map = pd.DataFrame({"isoform": usage.columns.astype(str), "gene": list(isoform_gene)})
+     long = long.merge(gene_map, on="isoform", how="left")
+     if obs_columns is not None:
+         meta = obs_columns.copy()
+         meta["cell"] = meta.index.astype(str)
+         long = long.merge(meta, on="cell", how="left")
+     return long
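A minimal sketch of the melt-and-merge logic in `usage_long_table`, run on a made-up 2-cell × 3-isoform usage matrix (cell, isoform, and gene names here are hypothetical, for illustration only):

```python
import pandas as pd

# Toy usage matrix: isoforms A1, A2 belong to geneA; B1 belongs to geneB.
usage = pd.DataFrame(
    [[0.7, 0.3, 1.0], [0.4, 0.6, 1.0]],
    index=["cell1", "cell2"],
    columns=["A1", "A2", "B1"],
)
isoform_gene = ["geneA", "geneA", "geneB"]

# Same melt + merge steps as usage_long_table (without the obs_columns join).
long = usage.reset_index(names="cell").melt(
    id_vars="cell", var_name="isoform", value_name="usage"
)
gene_map = pd.DataFrame({"isoform": usage.columns.astype(str), "gene": isoform_gene})
long = long.merge(gene_map, on="isoform", how="left")

print(long.shape)  # 2 cells x 3 isoforms = 6 rows, 4 columns
```

Each row of the long table is one (cell, isoform) pair with its usage value and gene label, which is the shape seaborn-style plotting functions expect.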
@@ -0,0 +1,68 @@
+ from __future__ import annotations
+
+ from typing import Dict, List, Sequence
+
+ import numpy as np
+
+ Array = np.ndarray
+
+
+ def evaluate_usage_metrics(
+     pred_usage: Array,
+     y_usage: Array,
+     y_counts: Array,
+     groups: Sequence[Array],
+     eps: float = 1e-8,
+ ) -> Dict[str, float]:
+     """Evaluate predicted within-gene isoform usage against observed usage."""
+     ce_num = 0.0
+     ce_den = 0.0
+     top1_correct = 0
+     top2_correct = 0
+     total = 0
+     mae_list: List[float] = []
+     pearson_list: List[float] = []
+     cosine_list: List[float] = []
+
+     for idx in groups:
+         pred = pred_usage[:, idx]
+         truth = y_usage[:, idx]
+         counts = y_counts[:, idx]
+         cov = counts.sum(axis=1)
+         mask = cov > 0
+         if not mask.any():
+             continue
+
+         ce = -(truth[mask] * np.log(pred[mask] + eps)).sum(axis=1)
+         w = np.sqrt(cov[mask] + 1.0)
+         ce_num += float(np.sum(ce * w))
+         ce_den += float(np.sum(w))
+
+         obs_top = np.argmax(truth[mask], axis=1)
+         pred_order = np.argsort(pred[mask], axis=1)[:, ::-1]
+         top1_correct += int(np.sum(pred_order[:, 0] == obs_top))
+         top2_correct += int(np.sum([obs_top[i] in pred_order[i, :2] for i in range(len(obs_top))]))
+         total += int(mask.sum())
+         mae_list.append(float(np.mean(np.abs(pred[mask] - truth[mask]))))
+
+         if len(idx) >= 3:
+             p = pred[mask]
+             t = truth[mask]
+             dot = np.sum(p * t, axis=1)
+             denom = np.linalg.norm(p, axis=1) * np.linalg.norm(t, axis=1) + eps
+             cosine_list.extend((dot / denom).astype(float).tolist())
+             for i in range(p.shape[0]):
+                 if np.std(p[i]) > eps and np.std(t[i]) > eps:
+                     r = np.corrcoef(p[i], t[i])[0, 1]
+                     if np.isfinite(r):
+                         pearson_list.append(float(r))
+
+     return {
+         "weighted_ce": float(ce_num / max(ce_den, eps)),
+         "top1": float(top1_correct / max(total, 1)),
+         "top2": float(top2_correct / max(total, 1)),
+         "group_mae": float(np.mean(mae_list)) if mae_list else float("nan"),
+         "gene_cell_pearson": float(np.mean(pearson_list)) if pearson_list else float("nan"),
+         "gene_cell_cosine": float(np.mean(cosine_list)) if cosine_list else float("nan"),
+         "n_effective_gene_cell_pairs": int(total),
+     }
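The top-1 and coverage-weighted cross-entropy terms in `evaluate_usage_metrics` can be traced by hand on a toy single-gene example. This is a standalone sketch mirroring the same formulas (`w = sqrt(cov + 1)`, epsilon-stabilized log), not a call into the packaged function:

```python
import numpy as np

# Two cells, one gene with three isoforms (columns 0-2).
pred = np.array([[0.6, 0.3, 0.1], [0.2, 0.5, 0.3]])
truth = np.array([[1.0, 0.0, 0.0], [0.0, 0.0, 1.0]])

# Top-1: does the predicted argmax match the observed argmax per cell?
obs_top = np.argmax(truth, axis=1)   # [0, 2]
pred_top = np.argmax(pred, axis=1)   # [0, 1] -> one of two cells matches
top1 = float(np.mean(pred_top == obs_top))

# Coverage-weighted cross-entropy with the same sqrt(cov + 1) weighting.
cov = np.array([10.0, 2.0])  # per-cell read coverage for this gene (made up)
eps = 1e-8
ce = -(truth * np.log(pred + eps)).sum(axis=1)
w = np.sqrt(cov + 1.0)
weighted_ce = float(np.sum(ce * w) / np.sum(w))
print(top1, weighted_ce)
```

The weighting gives well-covered cells more influence, so the cell with coverage 10 dominates the averaged cross-entropy here.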
@@ -0,0 +1,175 @@
+ from __future__ import annotations
+
+ from dataclasses import dataclass
+ from typing import List, Optional, Sequence, Tuple
+
+ import numpy as np
+ import torch
+ from torch import nn
+
+ Array = np.ndarray
+
+
+ @dataclass
+ class IsoVAEConfig:
+     """Neural network configuration for IsoVAE."""
+
+     gene_hidden: Tuple[int, int] = (512, 256)
+     iso_hidden: Tuple[int, int] = (512, 256)
+     decoder_hidden: Tuple[int, int] = (256, 512)
+     latent_dim: int = 32
+     dropout: float = 0.20
+
+
+ class GaussianEncoder(nn.Module):
+     def __init__(
+         self,
+         in_dim: int,
+         hidden: Tuple[int, int],
+         latent_dim: int,
+         dropout: float,
+     ) -> None:
+         super().__init__()
+         h1, h2 = hidden
+         self.net = nn.Sequential(
+             nn.Linear(in_dim, h1),
+             nn.LayerNorm(h1),
+             nn.GELU(),
+             nn.Dropout(dropout),
+             nn.Linear(h1, h2),
+             nn.LayerNorm(h2),
+             nn.GELU(),
+             nn.Dropout(dropout),
+         )
+         self.mu = nn.Linear(h2, latent_dim)
+         self.logvar = nn.Linear(h2, latent_dim)
+
+     def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+         h = self.net(x)
+         return self.mu(h), torch.clamp(self.logvar(h), min=-10.0, max=6.0)
+
+
+ class IsoVAEModel(nn.Module):
+     """Multimodal hierarchical VAE for isoform-usage prediction and denoising."""
+
+     def __init__(
+         self,
+         n_gene_inputs: int,
+         n_isoforms: int,
+         n_gene_groups: int,
+         config: IsoVAEConfig,
+     ) -> None:
+         super().__init__()
+         self.n_gene_inputs = int(n_gene_inputs)
+         self.n_isoforms = int(n_isoforms)
+         self.n_gene_groups = int(n_gene_groups)
+         self.latent_dim = int(config.latent_dim)
+
+         self.gene_encoder = GaussianEncoder(
+             self.n_gene_inputs, config.gene_hidden, config.latent_dim, config.dropout
+         )
+         self.iso_encoder = GaussianEncoder(
+             self.n_isoforms + self.n_gene_groups,
+             config.iso_hidden,
+             config.latent_dim,
+             config.dropout,
+         )
+         d1, d2 = config.decoder_hidden
+         self.decoder = nn.Sequential(
+             nn.Linear(config.latent_dim, d1),
+             nn.LayerNorm(d1),
+             nn.GELU(),
+             nn.Dropout(config.dropout),
+             nn.Linear(d1, d2),
+             nn.LayerNorm(d2),
+             nn.GELU(),
+             nn.Dropout(config.dropout),
+         )
+         self.iso_head = nn.Linear(d2, self.n_isoforms)
+
+     @staticmethod
+     def reparameterize(mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
+         if not torch.is_grad_enabled():
+             return mu
+         std = torch.exp(0.5 * logvar)
+         return mu + torch.randn_like(std) * std
+
+     def encode_gene(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+         mu, logvar = self.gene_encoder(x)
+         z = self.reparameterize(mu, logvar) if self.training else mu
+         return z, mu, logvar
+
+     def encode_iso(self, iso_input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+         mu, logvar = self.iso_encoder(iso_input)
+         z = self.reparameterize(mu, logvar) if self.training else mu
+         return z, mu, logvar
+
+     def decode_logits(self, z: torch.Tensor) -> torch.Tensor:
+         return self.iso_head(self.decoder(z))
+
+     def predict_from_gene(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+         z, mu, logvar = self.encode_gene(x)
+         return self.decode_logits(z), mu, logvar
+
+     def denoise_from_iso(self, iso_input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+         z, mu, logvar = self.encode_iso(iso_input)
+         return self.decode_logits(z), mu, logvar
+
+
+ def usage_from_counts_torch(
+     y_counts: torch.Tensor,
+     group_tensors: Sequence[torch.Tensor],
+     eps: float = 1e-8,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+     """Convert count tensor to usage tensor plus gene-level coverage features."""
+     usage = torch.zeros_like(y_counts)
+     covs: List[torch.Tensor] = []
+     for idx in group_tensors:
+         counts_g = y_counts.index_select(1, idx)
+         cov = counts_g.sum(dim=1, keepdim=True)
+         usage_g = counts_g / (cov + eps)
+         usage[:, idx] = torch.where(cov > 0, usage_g, torch.zeros_like(usage_g))
+         covs.append(cov)
+     return usage, torch.cat(covs, dim=1)
+
+
+ def iso_encoder_input_from_counts(
+     y_counts: torch.Tensor,
+     group_tensors: Sequence[torch.Tensor],
+ ) -> torch.Tensor:
+     """Build long-read encoder input: isoform usage + standardized log gene coverage."""
+     usage, cov = usage_from_counts_torch(y_counts, group_tensors)
+     cov_feat = torch.log1p(cov)
+     cov_feat = (cov_feat - cov_feat.mean(dim=1, keepdim=True)) / (
+         cov_feat.std(dim=1, keepdim=True) + 1e-6
+     )
+     return torch.cat([usage, cov_feat], dim=1)
+
+
+ def logits_to_usage(
+     logits: torch.Tensor,
+     groups: Sequence[Array],
+ ) -> Array:
+     """Apply gene-wise softmax to logits and return isoform-usage proportions."""
+     out = np.zeros(tuple(logits.shape), dtype=np.float32)
+     for idx_np in groups:
+         idx = torch.as_tensor(idx_np, dtype=torch.long, device=logits.device)
+         out[:, idx_np] = torch.softmax(logits.index_select(1, idx), dim=1).detach().cpu().numpy()
+     return out
+
+
+ def make_config_from_checkpoint(config_dict: dict, state_dict: Optional[dict] = None) -> IsoVAEConfig:
+     """Create a clean IsoVAEConfig from older checkpoints.
+
+     Historical checkpoints may contain extra keys from attention experiments.
+     They are ignored here so that the public package loads the final model.
+     """
+     allowed = {"gene_hidden", "iso_hidden", "decoder_hidden", "latent_dim", "dropout"}
+     clean = {k: v for k, v in (config_dict or {}).items() if k in allowed}
+     if "gene_hidden" in clean:
+         clean["gene_hidden"] = tuple(clean["gene_hidden"])
+     if "iso_hidden" in clean:
+         clean["iso_hidden"] = tuple(clean["iso_hidden"])
+     if "decoder_hidden" in clean:
+         clean["decoder_hidden"] = tuple(clean["decoder_hidden"])
+     return IsoVAEConfig(**clean)
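A NumPy-only sketch of the gene-wise softmax that `logits_to_usage` applies: each gene's isoform columns are normalized independently, so every gene's usage proportions sum to one per cell. Toy inputs below are hypothetical; the packaged function does the same per-gene normalization with torch:

```python
import numpy as np

def genewise_softmax(logits: np.ndarray, groups) -> np.ndarray:
    """Softmax each gene's isoform columns independently."""
    out = np.zeros_like(logits, dtype=np.float64)
    for idx in groups:
        block = logits[:, idx]
        block = block - block.max(axis=1, keepdims=True)  # numerical stability
        expb = np.exp(block)
        out[:, idx] = expb / expb.sum(axis=1, keepdims=True)
    return out

# Two cells, four isoforms: gene 1 owns columns 0-1, gene 2 owns columns 2-3.
logits = np.array([[2.0, 1.0, 0.0, 3.0], [0.0, 0.0, 1.0, 1.0]])
groups = [np.array([0, 1]), np.array([2, 3])]
usage = genewise_softmax(logits, groups)
print(usage)
```

A plain softmax over all four columns would couple unrelated genes; normalizing per group keeps each gene's isoform proportions on their own simplex.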
@@ -0,0 +1,20 @@
+ from __future__ import annotations
+
+ import random
+
+ import numpy as np
+ import torch
+
+
+ def set_seed(seed: int = 42) -> None:
+     """Set random seeds for reproducible NumPy/PyTorch experiments."""
+     random.seed(seed)
+     np.random.seed(seed)
+     torch.manual_seed(seed)
+     if torch.cuda.is_available():
+         torch.cuda.manual_seed_all(seed)
+
+
+ def select_device(prefer_cuda: bool = True) -> str:
+     """Return ``cuda`` when available and requested, otherwise ``cpu``."""
+     return "cuda" if prefer_cuda and torch.cuda.is_available() else "cpu"
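The seeding pattern above can be checked with the `random`/NumPy portion alone; the torch calls follow the same re-seed-then-redraw shape. A minimal sketch without the torch dependency:

```python
import random
import numpy as np

def set_seed(seed: int = 42) -> None:
    # Global seeding for the stdlib and NumPy RNGs (torch.manual_seed is analogous).
    random.seed(seed)
    np.random.seed(seed)

set_seed(123)
a = (random.random(), np.random.rand(3).tolist())
set_seed(123)
b = (random.random(), np.random.rand(3).tolist())
print(a == b)  # re-seeding replays the identical draw sequence
```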
@@ -0,0 +1,62 @@
+ from __future__ import annotations
+
+ from pathlib import Path
+ from typing import Optional, Sequence
+
+ import matplotlib.pyplot as plt
+ import pandas as pd
+ import seaborn as sns
+
+
+ def plot_gene_usage_boxplot(
+     usage_long: pd.DataFrame,
+     gene: str,
+     groupby: Optional[str] = None,
+     ax: Optional[plt.Axes] = None,
+ ) -> plt.Axes:
+     """Plot isoform-usage distributions for one gene.
+
+     ``usage_long`` should contain at least columns: ``gene``, ``isoform``,
+     ``usage`` and optionally a grouping column such as stage or cell type.
+     """
+     data = usage_long.loc[usage_long["gene"].astype(str) == str(gene)].copy()
+     if data.empty:
+         raise ValueError(f"No rows found for gene={gene!r}.")
+     if ax is None:
+         _, ax = plt.subplots(figsize=(max(5, data["isoform"].nunique() * 1.2), 4))
+     if groupby and groupby in data.columns:
+         sns.boxplot(data=data, x=groupby, y="usage", hue="isoform", ax=ax, fliersize=0.5)
+         ax.tick_params(axis="x", rotation=45)
+     else:
+         sns.violinplot(data=data, x="isoform", y="usage", ax=ax, cut=0, inner="box")
+         ax.tick_params(axis="x", rotation=45)
+     ax.set_title(f"{gene} isoform usage")
+     ax.set_xlabel(groupby if groupby else "Isoform")
+     ax.set_ylabel("Isoform usage")
+     return ax
+
+
+ def plot_usage_heatmap(
+     usage: pd.DataFrame,
+     isoforms: Optional[Sequence[str]] = None,
+     max_cells: int = 200,
+     ax: Optional[plt.Axes] = None,
+ ) -> plt.Axes:
+     """Plot a compact heatmap of cell-by-isoform usage values."""
+     mat = usage.loc[:, list(isoforms)] if isoforms is not None else usage
+     if mat.shape[0] > max_cells:
+         mat = mat.iloc[:max_cells]
+     if ax is None:
+         _, ax = plt.subplots(figsize=(8, 5))
+     sns.heatmap(mat, cmap="viridis", xticklabels=False, yticklabels=False, ax=ax)
+     ax.set_xlabel("Isoforms")
+     ax.set_ylabel("Cells")
+     return ax
+
+
+ def savefig(path: str | Path, dpi: int = 300) -> None:
+     """Save current matplotlib figure with tight layout."""
+     path = Path(path)
+     path.parent.mkdir(parents=True, exist_ok=True)
+     plt.tight_layout()
+     plt.savefig(path, dpi=dpi, bbox_inches="tight")
@@ -0,0 +1,125 @@
+ Metadata-Version: 2.4
+ Name: isovae
+ Version: 0.1.0
+ Summary: IsoVAE: isoform-usage prediction and long-read isoform-usage denoising for single-cell RNA-seq
+ Author: IsoVAE developers
+ License-Expression: MIT
+ Project-URL: Homepage, https://github.com/your-username/IsoVAE
+ Project-URL: Documentation, https://your-username.github.io/IsoVAE/
+ Project-URL: Repository, https://github.com/your-username/IsoVAE
+ Project-URL: Issues, https://github.com/your-username/IsoVAE/issues
+ Keywords: single-cell RNA-seq,isoform usage,long-read RNA-seq,variational autoencoder,denoising
+ Requires-Python: >=3.10
+ Description-Content-Type: text/markdown
+ License-File: LICENSE
+ Requires-Dist: numpy>=1.23
+ Requires-Dist: pandas>=1.5
+ Requires-Dist: scipy>=1.9
+ Requires-Dist: scikit-learn>=1.2
+ Requires-Dist: anndata>=0.9
+ Requires-Dist: torch>=2.0
+ Requires-Dist: matplotlib>=3.6
+ Requires-Dist: seaborn>=0.12
+ Provides-Extra: docs
+ Requires-Dist: mkdocs>=1.5; extra == "docs"
+ Requires-Dist: mkdocs-material>=9.5; extra == "docs"
+ Requires-Dist: mkdocstrings[python]>=0.24; extra == "docs"
+ Provides-Extra: dev
+ Requires-Dist: pytest; extra == "dev"
+ Requires-Dist: ruff; extra == "dev"
+ Dynamic: license-file
+
+ # IsoVAE
+
+ IsoVAE is a Python package for single-cell isoform-usage analysis. It supports:
+
+ 1. **Isoform-usage prediction** from short-read single-cell gene-expression profiles.
+ 2. **Long-read isoform-usage denoising** from sparse long-read isoform count matrices.
+
+ IsoVAE models **within-gene isoform usage proportions**, not absolute transcript abundance.
+
+ ## Installation
+
+ ```bash
+ pip install isovae
+ ```
+
+ For local development:
+
+ ```bash
+ git clone https://github.com/your-username/IsoVAE.git
+ cd IsoVAE
+ pip install -e .
+ ```
+
+ ## Quick start
+
+ ```python
+ import scanpy as sc
+ from isovae import (
+     load_artifact,
+     reconstruct_preprocessor_from_training_data,
+     predict_isoform_usage,
+     denoise_isoform_usage,
+ )
+
+ model_path = "path/to/vae_xda_model.pt"
+
+ gene_train = sc.read("path/to/training_gene_matrix.h5ad")
+ iso_train = sc.read("path/to/training_isoform_matrix.h5ad")
+
+ preprocessor = reconstruct_preprocessor_from_training_data(
+     model_path,
+     adata_gene_train=gene_train,
+     adata_iso_train=iso_train,
+     seed=42,
+ )
+
+ artifact = load_artifact(model_path, preprocessor=preprocessor, device="cpu")
+
+ # Predict isoform usage from short-read data.
+ gene_query = sc.read("path/to/query_gene_matrix.h5ad")
+ pred_usage, pred_meta = predict_isoform_usage(artifact, gene_query)
+ pred_usage.to_csv("predicted_isoform_usage.csv")
+
+ # Denoise long-read isoform usage.
+ iso_query = sc.read("path/to/query_isoform_matrix.h5ad")
+ denoised_usage, noisy_usage, denoise_meta = denoise_isoform_usage(artifact, iso_query)
+ denoised_usage.to_csv("denoised_isoform_usage.csv")
+ ```
+
+ ## Documentation
+
+ The documentation source is in `docs/` and can be built with MkDocs:
+
+ ```bash
+ pip install -e ".[docs]"
+ mkdocs serve
+ ```
+
+ To deploy to GitHub Pages:
+
+ ```bash
+ mkdocs gh-deploy
+ ```
+
+ See `docs/deployment.md` for deployment instructions for GitHub Pages, Read the Docs, Netlify, and Vercel.
+
+ ## Repository layout
+
+ ```text
+ .
+ ├── src/isovae/      # Python package
+ ├── docs/            # Documentation source
+ ├── mkdocs.yml       # Documentation configuration
+ ├── pyproject.toml   # Package metadata
+ ├── requirements.txt
+ ├── LICENSE
+ └── README.md
+ ```
+
+ Large data files, AnnData objects, model checkpoints, and manuscript outputs are not included in the package.
+
+ ## Citation
+
+ If you use IsoVAE, please cite the accompanying manuscript after publication.
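To make the README's "within-gene isoform usage proportions" concrete: each gene's isoform counts are normalized to proportions that sum to one, independently per gene. A minimal sketch with made-up counts:

```python
import numpy as np

# One cell: geneA has two isoforms with counts [8, 2]; geneB has three with [0, 5, 5].
counts_gene_a = np.array([8.0, 2.0])
counts_gene_b = np.array([0.0, 5.0, 5.0])

usage_a = counts_gene_a / counts_gene_a.sum()  # proportions within geneA
usage_b = counts_gene_b / counts_gene_b.sum()  # proportions within geneB
print(usage_a.tolist(), usage_b.tolist())
```

Note the proportions carry no information about absolute abundance: a gene with 10 reads and a gene with 10,000 reads can have identical usage vectors.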
@@ -0,0 +1,15 @@
+ LICENSE
+ README.md
+ pyproject.toml
+ src/isovae/__init__.py
+ src/isovae/data.py
+ src/isovae/inference.py
+ src/isovae/metrics.py
+ src/isovae/model.py
+ src/isovae/utils.py
+ src/isovae/viz.py
+ src/isovae.egg-info/PKG-INFO
+ src/isovae.egg-info/SOURCES.txt
+ src/isovae.egg-info/dependency_links.txt
+ src/isovae.egg-info/requires.txt
+ src/isovae.egg-info/top_level.txt
@@ -0,0 +1,17 @@
+ numpy>=1.23
+ pandas>=1.5
+ scipy>=1.9
+ scikit-learn>=1.2
+ anndata>=0.9
+ torch>=2.0
+ matplotlib>=3.6
+ seaborn>=0.12
+
+ [dev]
+ pytest
+ ruff
+
+ [docs]
+ mkdocs>=1.5
+ mkdocs-material>=9.5
+ mkdocstrings[python]>=0.24
@@ -0,0 +1 @@
+ isovae