univi 0.3.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- univi/__init__.py +120 -0
- univi/__main__.py +5 -0
- univi/cli.py +60 -0
- univi/config.py +340 -0
- univi/data.py +345 -0
- univi/diagnostics.py +130 -0
- univi/evaluation.py +632 -0
- univi/hyperparam_optimization/__init__.py +17 -0
- univi/hyperparam_optimization/common.py +339 -0
- univi/hyperparam_optimization/run_adt_hparam_search.py +109 -0
- univi/hyperparam_optimization/run_atac_hparam_search.py +109 -0
- univi/hyperparam_optimization/run_citeseq_hparam_search.py +137 -0
- univi/hyperparam_optimization/run_multiome_hparam_search.py +145 -0
- univi/hyperparam_optimization/run_rna_hparam_search.py +111 -0
- univi/hyperparam_optimization/run_teaseq_hparam_search.py +146 -0
- univi/interpretability.py +399 -0
- univi/matching.py +394 -0
- univi/models/__init__.py +8 -0
- univi/models/decoders.py +249 -0
- univi/models/encoders.py +848 -0
- univi/models/mlp.py +36 -0
- univi/models/tokenizers.py +376 -0
- univi/models/transformer.py +249 -0
- univi/models/univi.py +1284 -0
- univi/objectives.py +46 -0
- univi/pipeline.py +194 -0
- univi/plotting.py +126 -0
- univi/trainer.py +478 -0
- univi/utils/__init__.py +5 -0
- univi/utils/io.py +621 -0
- univi/utils/logging.py +16 -0
- univi/utils/seed.py +18 -0
- univi/utils/stats.py +23 -0
- univi/utils/torch_utils.py +23 -0
- univi-0.3.4.dist-info/METADATA +908 -0
- univi-0.3.4.dist-info/RECORD +40 -0
- univi-0.3.4.dist-info/WHEEL +5 -0
- univi-0.3.4.dist-info/entry_points.txt +2 -0
- univi-0.3.4.dist-info/licenses/LICENSE +21 -0
- univi-0.3.4.dist-info/top_level.txt +1 -0
univi/matching.py
ADDED
|
@@ -0,0 +1,394 @@
|
|
|
1
|
+
# univi/matching.py
|
|
2
|
+
|
|
3
|
+
import warnings
|
|
4
|
+
import numpy as np
|
|
5
|
+
from sklearn.metrics import pairwise_distances
|
|
6
|
+
from sklearn.neighbors import NearestNeighbors
|
|
7
|
+
from scipy.optimize import linear_sum_assignment
|
|
8
|
+
from typing import Optional, Dict
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def _subsample_indices(n: int, max_cells: int, rng: np.random.Generator) -> np.ndarray:
|
|
12
|
+
"""
|
|
13
|
+
Helper to subsample up to max_cells indices from range(n) without replacement.
|
|
14
|
+
"""
|
|
15
|
+
idx_full = np.arange(n)
|
|
16
|
+
if n <= max_cells:
|
|
17
|
+
return idx_full
|
|
18
|
+
return rng.choice(idx_full, size=max_cells, replace=False)
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
# ---------------------------------------------------------------------------
|
|
22
|
+
# 1. Basic bipartite matching (Hungarian) in a shared embedding
|
|
23
|
+
# ---------------------------------------------------------------------------
|
|
24
|
+
|
|
25
|
+
def bipartite_match_adata(
    adata_A,
    adata_B,
    emb_key: str = "X_pca",
    metric: str = "euclidean",
    max_cells: int = 20000,
    random_state: int = 0,
):
    """
    Bipartite matching between cells in adata_A and adata_B based on a shared embedding.

    This is the basic "Hungarian in latent space" matcher. It assumes that both
    adata_A and adata_B have a *comparable* embedding in .obsm[emb_key], e.g.:

      - both are in the same PCA space, or
      - both are in a shared latent space (CCA, UniVI encoder, etc.)

    ``n = min(n_A, n_B, max_cells)`` cells are drawn without replacement from
    each object, a dense ``n x n`` cost matrix is built with
    ``sklearn.metrics.pairwise_distances(metric=metric)``, and the min-cost
    one-to-one assignment is solved with ``scipy.optimize.linear_sum_assignment``.

    Parameters
    ----------
    adata_A, adata_B
        Objects exposing ``.obsm[emb_key]`` (e.g. AnnData).
    emb_key : str
        Key of the embedding in ``.obsm``.
    metric : str
        Distance metric forwarded to ``pairwise_distances``.
    max_cells : int
        Upper bound on the number of cells matched from each side.
    random_state : int
        Seed for the subsampling.

    Returns
    -------
    matched_A, matched_B : np.ndarray of int
        Equal-length arrays of row positions into adata_A / adata_B; pair ``k``
        is ``(matched_A[k], matched_B[k])``. Both are empty when either object
        has no cells or ``max_cells <= 0``.
    """
    rng = np.random.default_rng(random_state)

    XA = np.asarray(adata_A.obsm[emb_key])
    XB = np.asarray(adata_B.obsm[emb_key])

    na, nb = XA.shape[0], XB.shape[0]
    n = min(na, nb, max_cells)
    if n <= 0:
        # Nothing to match; pairwise_distances rejects empty inputs, so return
        # empty index arrays (callers such as the stratified wrapper check .size).
        empty = np.array([], dtype=int)
        return empty, empty.copy()

    # Both subsamples have exactly n entries, so the cost matrix is square.
    idx_A = _subsample_indices(na, n, rng)
    idx_B = _subsample_indices(nb, n, rng)

    # Dense n x n cost matrix: memory and solve time grow quickly with n,
    # which is what max_cells keeps in check.
    D = pairwise_distances(XA[idx_A], XB[idx_B], metric=metric)

    # Min-cost one-to-one assignment; with a square matrix every row is matched.
    row_ind, col_ind = linear_sum_assignment(D)

    # Map positions within the subsamples back to rows of the input objects.
    matched_A = idx_A[row_ind]
    matched_B = idx_B[col_ind]
    return matched_A, matched_B
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
# ---------------------------------------------------------------------------
|
|
69
|
+
# 2. Stratified bipartite matching (per group / cell type)
|
|
70
|
+
# ---------------------------------------------------------------------------
|
|
71
|
+
|
|
72
|
+
def stratified_bipartite_match_adata(
    adata_A,
    adata_B,
    group_key_A: str,
    group_key_B: Optional[str] = None,
    group_map: Optional[Dict] = None,
    emb_key: str = "X_pca",
    metric: str = "euclidean",
    max_cells_per_group: int = 20000,
    random_state: int = 0,
    shuffle: bool = True,
):
    """
    Per-group (e.g. per-celltype) bipartite matching in a shared embedding.

    This wraps `bipartite_match_adata` but runs it separately within each group,
    then concatenates the matches.

    Parameters
    ----------
    adata_A, adata_B
        Objects exposing ``.obs`` (DataFrame-like), ``.obsm[emb_key]`` and
        positional row indexing (e.g. AnnData).
    group_key_A, group_key_B : str
        Columns of ``adata_A.obs`` / ``adata_B.obs`` holding group labels;
        ``group_key_B`` defaults to ``group_key_A``. Labels are compared as
        strings.
    group_map : dict, optional
        Maps an A group label (as ``str``) to the B group label it should be
        matched against, e.g. the output of ``cluster_centroid_matching_adata``.
        A groups missing from the map (or mapped to None) are skipped. Target
        labels are converted with ``str()`` so they compare against the
        stringified B labels. If None, each A group is matched to the B group
        with the same label.
    emb_key, metric
        Forwarded to ``bipartite_match_adata``.
    max_cells_per_group : int
        Cap on matched pairs per group.
    random_state : int
        Seed for the final shuffle; the same seed is also passed to every
        per-group ``bipartite_match_adata`` call.
    shuffle : bool
        If True, apply one random permutation to both output arrays so pairs
        are no longer ordered by group.

    Returns
    -------
    matched_A, matched_B : np.ndarray of int
        Row positions into adata_A / adata_B of the matched pairs.
    group_counts : dict[str, int]
        Number of matched pairs per A group label.

    Raises
    ------
    KeyError
        If a group key is missing from the corresponding ``.obs``.
    RuntimeError
        If no group produced any match.
    """
    rng = np.random.default_rng(random_state)

    if group_key_B is None:
        group_key_B = group_key_A

    if group_key_A not in adata_A.obs.columns:
        raise KeyError(f"{group_key_A!r} not found in adata_A.obs")
    if group_key_B not in adata_B.obs.columns:
        raise KeyError(f"{group_key_B!r} not found in adata_B.obs")

    labels_A = adata_A.obs[group_key_A].astype(str).to_numpy()
    labels_B = adata_B.obs[group_key_B].astype(str).to_numpy()

    all_matched_A = []
    all_matched_B = []
    group_counts: Dict[str, int] = {}

    for gA in np.unique(labels_A):
        # Determine which label in B we should match to
        if group_map is None:
            gB = gA
        else:
            gB = group_map.get(gA, None)
            if gB is None:
                # group not mapped; skip
                continue
            # labels_B holds strings, so compare against the stringified target;
            # otherwise e.g. an int target would silently never match.
            gB = str(gB)

        idx_A_g = np.flatnonzero(labels_A == gA)
        idx_B_g = np.flatnonzero(labels_B == gB)
        if idx_A_g.size == 0 or idx_B_g.size == 0:
            continue

        # n per group for bipartite matching (<= 0 only if the cap is <= 0)
        n_grp = min(idx_A_g.size, idx_B_g.size, max_cells_per_group)
        if n_grp <= 0:
            continue

        # Match within the group subsets (presumably AnnData views).
        mA_local, mB_local = bipartite_match_adata(
            adata_A[idx_A_g],
            adata_B[idx_B_g],
            emb_key=emb_key,
            metric=metric,
            max_cells=n_grp,
            random_state=random_state,
        )
        if mA_local.size == 0:
            continue

        # Local results index the group subsets; map back to full-object rows.
        mA = idx_A_g[mA_local]
        mB = idx_B_g[mB_local]

        all_matched_A.append(mA)
        all_matched_B.append(mB)
        group_counts[str(gA)] = mA.size

    if not all_matched_A:
        raise RuntimeError("No stratified matches were found for any group.")

    matched_A = np.concatenate(all_matched_A)
    matched_B = np.concatenate(all_matched_B)

    if shuffle:
        # Same permutation for both arrays keeps the pairs aligned.
        perm = rng.permutation(matched_A.size)
        matched_A = matched_A[perm]
        matched_B = matched_B[perm]

    return matched_A, matched_B, group_counts
|
|
165
|
+
|
|
166
|
+
|
|
167
|
+
# ---------------------------------------------------------------------------
|
|
168
|
+
# 3. Mutual Nearest Neighbor (MNN) anchors
|
|
169
|
+
# ---------------------------------------------------------------------------
|
|
170
|
+
|
|
171
|
+
def mnn_anchors_adata(
    adata_A,
    adata_B,
    emb_key: str = "X_pca",
    k: int = 20,
    max_cells: int = 20000,
    random_state: int = 0,
):
    """
    Mutual Nearest Neighbor (MNN) anchors between adata_A and adata_B.

    Up to ``max_cells`` cells are subsampled independently from each object.
    A pair (a, b) is an anchor when b is among the ``k`` nearest B cells of a
    *and* a is among the ``k`` nearest A cells of b. Neighbors come from
    scikit-learn ``NearestNeighbors`` with its default metric.

    Parameters
    ----------
    adata_A, adata_B
        Objects exposing ``.obsm[emb_key]`` (e.g. AnnData) in a shared space.
    emb_key : str
        Key of the embedding in ``.obsm``.
    k : int
        Neighbors per cell (clipped to the size of the other subsample).
    max_cells : int
        Cap on cells subsampled from each object.
    random_state : int
        Seed for the subsampling.

    Returns
    -------
    anchors_A, anchors_B : np.ndarray of int
        Equal-length arrays of row positions into adata_A / adata_B; a cell may
        appear in several pairs. Both are empty (and a warning is issued) when
        either side has no cells or no mutual pairs exist.

    Raises
    ------
    ValueError
        If ``k < 1``.
    """
    if k < 1:
        raise ValueError(f"mnn_anchors_adata: k must be >= 1, got {k!r}")

    rng = np.random.default_rng(random_state)

    XA = np.asarray(adata_A.obsm[emb_key])
    XB = np.asarray(adata_B.obsm[emb_key])

    na, nb = XA.shape[0], XB.shape[0]
    empty = np.array([], dtype=int)
    if na == 0 or nb == 0 or max_cells <= 0:
        # NearestNeighbors cannot be fitted/queried on empty data.
        warnings.warn("mnn_anchors_adata: empty input; no anchors computed.")
        return empty, empty.copy()

    idx_A = _subsample_indices(na, max_cells, rng)
    idx_B = _subsample_indices(nb, max_cells, rng)
    n_A_sub, n_B_sub = idx_A.size, idx_B.size
    XA_sub = XA[idx_A]
    XB_sub = XB[idx_B]

    # A query cannot request more neighbors than the fitted set contains.
    k_A = min(k, n_B_sub)
    k_B = min(k, n_A_sub)

    # Neighbor lists hold *local* positions within the subsamples:
    # ind_A2B[i] = the k_A nearest B cells of A cell i; ind_B2A[j] likewise.
    ind_A2B = (
        NearestNeighbors(n_neighbors=k_A)
        .fit(XB_sub)
        .kneighbors(XA_sub, return_distance=False)
    )
    ind_B2A = (
        NearestNeighbors(n_neighbors=k_B)
        .fit(XA_sub)
        .kneighbors(XB_sub, return_distance=False)
    )

    # Encode each directed neighbor pair (a, b) as the integer a * n_B_sub + b,
    # so mutuality becomes one vectorized membership test instead of a Python
    # double loop. Row-major ravel keeps the ordering: A cells in subsample
    # order, then each cell's neighbors in the order kneighbors returned them.
    pair_A = np.repeat(np.arange(n_A_sub), k_A)
    pair_B = ind_A2B.ravel()
    codes_A2B = pair_A * n_B_sub + pair_B
    codes_B2A = ind_B2A.ravel() * n_B_sub + np.repeat(np.arange(n_B_sub), k_B)
    is_mutual = np.isin(codes_A2B, codes_B2A)

    if not is_mutual.any():
        warnings.warn("mnn_anchors_adata: no mutual nearest neighbors found.")
        return empty, empty.copy()

    anchors_A = idx_A[pair_A[is_mutual]].astype(int)
    anchors_B = idx_B[pair_B[is_mutual]].astype(int)
    return anchors_A, anchors_B
|
|
230
|
+
|
|
231
|
+
|
|
232
|
+
# ---------------------------------------------------------------------------
|
|
233
|
+
# 4. Cluster / cell-type centroid matching (for building group maps)
|
|
234
|
+
# ---------------------------------------------------------------------------
|
|
235
|
+
|
|
236
|
+
def cluster_centroid_matching_adata(
    adata_A,
    adata_B,
    group_key_A: str,
    group_key_B: Optional[str] = None,
    emb_key: str = "X_pca",
    metric: str = "euclidean",
):
    """
    Match cluster / cell-type centroids across datasets via Hungarian.

    Each group's centroid is the mean of its rows in ``.obsm[emb_key]``.
    Centroids are compared with ``pairwise_distances(metric=metric)`` and
    paired one-to-one with ``linear_sum_assignment``. If the two objects have
    different numbers of groups, only ``min(n_groups_A, n_groups_B)`` pairs
    are produced and the remaining groups are absent from the result.

    Parameters
    ----------
    adata_A, adata_B
        Objects exposing ``.obs`` (DataFrame-like) and ``.obsm[emb_key]``.
    group_key_A, group_key_B : str
        Label columns; ``group_key_B`` defaults to ``group_key_A``. Labels are
        compared as strings.
    emb_key : str
        Key of the embedding in ``.obsm``.
    metric : str
        Distance metric between centroids.

    Returns
    -------
    dict[str, str]
        Mapping from A group label to matched B group label, usable as
        ``group_map`` for ``stratified_bipartite_match_adata``. Empty if either
        object has no cells.

    Raises
    ------
    KeyError
        If a group key is missing from the corresponding ``.obs``.
    """
    if group_key_B is None:
        group_key_B = group_key_A

    if group_key_A not in adata_A.obs.columns:
        raise KeyError(f"{group_key_A!r} not found in adata_A.obs")
    if group_key_B not in adata_B.obs.columns:
        raise KeyError(f"{group_key_B!r} not found in adata_B.obs")

    labels_A = adata_A.obs[group_key_A].astype(str).to_numpy()
    labels_B = adata_B.obs[group_key_B].astype(str).to_numpy()

    XA = np.asarray(adata_A.obsm[emb_key])
    XB = np.asarray(adata_B.obsm[emb_key])

    groups_A = np.unique(labels_A)
    groups_B = np.unique(labels_B)
    if groups_A.size == 0 or groups_B.size == 0:
        # No cells on one side: nothing to match (np.vstack([]) would raise).
        return {}

    # One centroid per group; rows are aligned with groups_A / groups_B.
    centroids_A = np.vstack([XA[labels_A == g].mean(axis=0) for g in groups_A])
    centroids_B = np.vstack([XB[labels_B == g].mean(axis=0) for g in groups_B])

    # Cost between centroids
    D = pairwise_distances(centroids_A, centroids_B, metric=metric)
    row_ind, col_ind = linear_sum_assignment(D)

    group_map: Dict[str, str] = {
        str(groups_A[i]): str(groups_B[j]) for i, j in zip(row_ind, col_ind)
    }
    return group_map
|
|
288
|
+
|
|
289
|
+
|
|
290
|
+
# ---------------------------------------------------------------------------
|
|
291
|
+
# 5. Gromov–Wasserstein OT-based anchors (optional; requires POT)
|
|
292
|
+
# ---------------------------------------------------------------------------
|
|
293
|
+
|
|
294
|
+
def gw_ot_anchors_adata(
    adata_A,
    adata_B,
    emb_key: str = "X_pca",
    max_cells: int = 3000,
    random_state: int = 0,
    normalize_distances: bool = True,
    metric: str = "euclidean",
):
    """
    Geometry-aware anchors via Gromov–Wasserstein optimal transport.

    Only *within-dataset* distance matrices enter the problem, so the two
    ``.obsm[emb_key]`` embeddings need not share a coordinate system (or even
    a dimensionality).

    Steps:
      1. subsample up to ``max_cells`` cells from each object;
      2. build the intra-dataset distance matrices with ``metric``;
      3. optionally rescale each matrix by its maximum;
      4. solve GW (uniform marginals, square loss) with
         ``ot.gromov.gromov_wasserstein``;
      5. pair every subsampled A cell with the B cell receiving the most
         transport mass from it.

    Parameters
    ----------
    adata_A, adata_B
        Objects exposing ``.obsm[emb_key]`` (e.g. AnnData).
    emb_key : str
        Key of the embedding in ``.obsm``.
    max_cells : int
        Cap on cells per object; the distance matrices are dense
        ``max_cells x max_cells``, hence the small default.
    random_state : int
        Seed for the subsampling.
    normalize_distances : bool
        Divide each distance matrix by its maximum (skipped if it is all zero).
    metric : str
        Metric for the intra-dataset distances (``pairwise_distances``).
        Defaults to ``"euclidean"``, the previously hard-coded value.

    Returns
    -------
    anchors_A, anchors_B : np.ndarray of int
        One anchor per subsampled A cell; B cells may repeat. Both are empty
        (with a warning) when either object has no cells or ``max_cells <= 0``.

    Raises
    ------
    ImportError
        If the optional ``pot`` package is not installed.
    """
    try:
        import ot  # type: ignore
    except ImportError as e:
        raise ImportError(
            "gw_ot_anchors_adata requires the 'pot' package. "
            "Install with: pip install pot"
        ) from e

    rng = np.random.default_rng(random_state)

    XA = np.asarray(adata_A.obsm[emb_key])
    XB = np.asarray(adata_B.obsm[emb_key])

    na, nb = XA.shape[0], XB.shape[0]
    if na == 0 or nb == 0 or max_cells <= 0:
        # Empty distance matrices would fail on .max() and the 1/n weights.
        warnings.warn("gw_ot_anchors_adata: empty input; no anchors computed.")
        empty = np.array([], dtype=int)
        return empty, empty.copy()

    idx_A = _subsample_indices(na, max_cells, rng)
    idx_B = _subsample_indices(nb, max_cells, rng)

    XA_sub = XA[idx_A]
    XB_sub = XB[idx_B]

    # Distance matrices within each dataset
    DA = pairwise_distances(XA_sub, XA_sub, metric=metric)
    DB = pairwise_distances(XB_sub, XB_sub, metric=metric)

    if normalize_distances:
        # Put both geometries on a comparable [0, 1] scale; an all-zero matrix
        # (e.g. a single cell) is left as-is to avoid dividing by zero.
        max_A = DA.max()
        if max_A > 0:
            DA = DA / max_A
        max_B = DB.max()
        if max_B > 0:
            DB = DB / max_B

    # Uniform weights
    p = np.ones(DA.shape[0]) / DA.shape[0]
    q = np.ones(DB.shape[0]) / DB.shape[0]

    # Compute GW coupling (rows: A subsample, columns: B subsample)
    T = ot.gromov.gromov_wasserstein(
        DA, DB, p, q, loss_fun="square_loss", verbose=False
    )

    # Hard assignment: for each A row, the B column with maximum coupling mass.
    best_B = np.asarray(T).argmax(axis=1)

    anchors_A = np.asarray(idx_A, dtype=int)
    anchors_B = np.asarray(idx_B[best_B], dtype=int)
    return anchors_A, anchors_B
|
|
356
|
+
|
|
357
|
+
|
|
358
|
+
# ---------------------------------------------------------------------------
|
|
359
|
+
# 6. Group latent statistics (for distribution-level alignment)
|
|
360
|
+
# ---------------------------------------------------------------------------
|
|
361
|
+
|
|
362
|
+
def group_latent_stats_adata(
    adata,
    group_key: str,
    emb_key: str = "X_pca",
):
    """
    Compute simple group-wise latent statistics (mean, covariance) to
    support distribution-level alignment strategies.

    Parameters
    ----------
    adata
        Object exposing ``.obs`` (DataFrame-like) and ``.obsm[emb_key]``
        (assumed 2-D, cells x dims, as in AnnData).
    group_key : str
        Column of ``adata.obs`` with group labels (compared as strings).
    emb_key : str
        Key of the embedding in ``.obsm``.

    Returns
    -------
    dict[str, dict]
        Per group label:
          - ``"mean"``: column means, shape ``(d,)``;
          - ``"cov"``: sample covariance (``np.cov``, ddof=1), always shape
            ``(d, d)``; all-NaN for a single-cell group, where it is undefined;
          - ``"n"``: number of cells in the group.

    Raises
    ------
    KeyError
        If ``group_key`` is not a column of ``adata.obs``.
    """
    if group_key not in adata.obs.columns:
        raise KeyError(f"{group_key!r} not found in adata.obs")

    labels = adata.obs[group_key].astype(str).to_numpy()
    X = np.asarray(adata.obsm[emb_key])

    stats = {}
    for g in np.unique(labels):
        Xg = X[labels == g]
        n_g = Xg.shape[0]
        if n_g > 1:
            # rowvar=False so that columns are variables; atleast_2d because
            # np.cov returns a 0-d array when there is a single column.
            cov = np.atleast_2d(np.cov(Xg, rowvar=False))
        else:
            # ddof=1 covariance of one sample is undefined. np.cov would give
            # NaN too, but with a RuntimeWarning (and a 0-d array when d == 1).
            d = Xg.shape[1]
            cov = np.full((d, d), np.nan)
        stats[str(g)] = {
            "mean": Xg.mean(axis=0),
            "cov": cov,
            "n": n_g,
        }

    return stats
|
univi/models/__init__.py
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
1
|
+
# univi/models/__init__.py
|
|
2
|
+
from __future__ import annotations
|
|
3
|
+
|
|
4
|
+
from .univi import UniVIMultiModalVAE
|
|
5
|
+
from .transformer import TransformerEncoder
|
|
6
|
+
from .tokenizers import build_tokenizer
|
|
7
|
+
|
|
8
|
+
# Public names re-exported by the univi.models package.
__all__ = [
    "UniVIMultiModalVAE",
    "TransformerEncoder",
    "build_tokenizer",
]
|
univi/models/decoders.py
ADDED
|
@@ -0,0 +1,249 @@
|
|
|
1
|
+
# univi/models/decoders.py
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from dataclasses import dataclass
|
|
6
|
+
from typing import Any, Dict, List
|
|
7
|
+
|
|
8
|
+
import torch
|
|
9
|
+
from torch import nn
|
|
10
|
+
import torch.nn.functional as F
|
|
11
|
+
|
|
12
|
+
from .mlp import build_mlp
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
@dataclass
class DecoderConfig:
    """Shared settings for the feed-forward decoders defined in this module.

    All four fields are forwarded to ``build_mlp`` by the decoders. Some
    decoders (diagonal Gaussian, ZINB) ask the MLP for ``2 * output_dim``
    units and split the result in two.
    """

    # Size of the decoded output vector.
    output_dim: int
    # Hidden layer widths, passed to build_mlp as-is.
    hidden_dims: List[int]
    # Dropout rate forwarded to build_mlp.
    dropout: float = 0.0
    # Whether build_mlp should add batch normalisation.
    batchnorm: bool = False
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class GaussianDecoder(nn.Module):
    """Map a latent code ``z`` to a mean reconstruction.

    ``forward`` returns the MLP output directly (no extra activation is
    applied here), to be paired with MSE / fixed-variance Gaussian losses.
    """

    def __init__(self, cfg: DecoderConfig, latent_dim: int):
        super().__init__()
        self.cfg = cfg
        # latent_dim -> cfg.hidden_dims -> cfg.output_dim
        self.net = build_mlp(
            in_dim=latent_dim,
            out_dim=cfg.output_dim,
            hidden_dims=cfg.hidden_dims,
            dropout=cfg.dropout,
            batchnorm=cfg.batchnorm,
        )

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        reconstruction = self.net(z)
        return reconstruction
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
class GaussianDiagDecoder(nn.Module):
    """Map ``z`` to the parameters of a diagonal Gaussian over the outputs.

    One MLP emits ``2 * cfg.output_dim`` units that are split along the last
    dimension: the first half is ``"mean"``, the second half ``"logvar"``
    (left unconstrained).
    """

    def __init__(self, cfg: DecoderConfig, latent_dim: int):
        super().__init__()
        self.cfg = cfg
        # Twice the output width: one half per Gaussian parameter.
        width = 2 * cfg.output_dim
        self.backbone = build_mlp(
            in_dim=latent_dim,
            hidden_dims=cfg.hidden_dims,
            out_dim=width,
            dropout=cfg.dropout,
            batchnorm=cfg.batchnorm,
        )

    def forward(self, z: torch.Tensor) -> Dict[str, torch.Tensor]:
        mean, logvar = torch.chunk(self.backbone(z), 2, dim=-1)
        return dict(mean=mean, logvar=logvar)
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
class BernoulliDecoder(nn.Module):
    """Map ``z`` to unnormalised Bernoulli logits, returned as ``{"logits"}``.

    No sigmoid is applied here; the loss is expected to work on logits.
    """

    def __init__(self, cfg: DecoderConfig, latent_dim: int):
        super().__init__()
        self.cfg = cfg
        # One logit per output feature.
        self.net = build_mlp(
            in_dim=latent_dim,
            out_dim=cfg.output_dim,
            hidden_dims=cfg.hidden_dims,
            dropout=cfg.dropout,
            batchnorm=cfg.batchnorm,
        )

    def forward(self, z: torch.Tensor) -> Dict[str, torch.Tensor]:
        return dict(logits=self.net(z))
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
class PoissonDecoder(nn.Module):
    """Map ``z`` to a non-negative Poisson rate.

    Returns ``"log_rate"`` (the raw MLP output) and ``"rate"``
    (``softplus`` of that raw output).

    NOTE(review): despite its name, ``log_rate`` is not ``log(rate)``:
    ``rate`` comes from softplus, not exp. Confirm which key the Poisson
    likelihood consumes, since a ``log_input=True``-style loss on ``log_rate``
    would imply ``exp(log_rate)``, a different rate than ``"rate"``.
    """

    def __init__(self, cfg: DecoderConfig, latent_dim: int):
        super().__init__()
        self.cfg = cfg
        self.net = build_mlp(
            in_dim=latent_dim,
            out_dim=cfg.output_dim,
            hidden_dims=cfg.hidden_dims,
            dropout=cfg.dropout,
            batchnorm=cfg.batchnorm,
        )

    def forward(self, z: torch.Tensor) -> Dict[str, torch.Tensor]:
        raw = self.net(z)
        # softplus keeps the rate non-negative and smooth.
        return {"log_rate": raw, "rate": F.softplus(raw)}
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
class NegativeBinomialDecoder(nn.Module):
    """Map ``z`` to negative-binomial parameters ``{"mu", "log_theta"}``.

    ``mu`` is ``softplus(MLP(z)) + eps``, so it is strictly positive.
    ``log_theta`` is a learned parameter that does not depend on ``z``. Its
    shape is ``(1,)`` for ``dispersion="global"`` and ``(cfg.output_dim,)``
    for ``dispersion="gene"``, and it is initialised to ``init_log_theta``.
    The same Parameter object is returned on every call; no batch broadcasting
    happens here. Any other ``dispersion`` raises ``ValueError``.
    """

    def __init__(
        self,
        cfg: DecoderConfig,
        latent_dim: int,
        dispersion: str = "gene",
        init_log_theta: float = 0.0,
        eps: float = 1e-8,
    ):
        super().__init__()
        self.cfg = cfg
        self.dispersion = dispersion
        self.eps = float(eps)

        self.mu_net = build_mlp(
            in_dim=latent_dim,
            out_dim=cfg.output_dim,
            hidden_dims=cfg.hidden_dims,
            dropout=cfg.dropout,
            batchnorm=cfg.batchnorm,
        )

        # One shared dispersion ("global") or one per output feature ("gene").
        if dispersion == "global":
            theta_shape = (1,)
        elif dispersion == "gene":
            theta_shape = (cfg.output_dim,)
        else:
            raise ValueError("Unknown dispersion mode: %r (expected 'global' or 'gene')" % dispersion)
        self.log_theta = nn.Parameter(torch.full(theta_shape, float(init_log_theta)))

    def forward(self, z: torch.Tensor) -> Dict[str, torch.Tensor]:
        positive_mu = F.softplus(self.mu_net(z)) + self.eps
        return {"mu": positive_mu, "log_theta": self.log_theta}
|
|
135
|
+
|
|
136
|
+
|
|
137
|
+
class ZeroInflatedNegativeBinomialDecoder(nn.Module):
    """Map ``z`` to ZINB parameters ``{"mu", "log_theta", "logit_pi"}``.

    One MLP emits ``2 * cfg.output_dim`` units. The first half becomes
    ``mu = softplus(.) + eps``, which is strictly positive. The second half is
    returned unchanged as ``logit_pi``, the zero-inflation logits.
    ``log_theta`` is a learned, ``z``-independent parameter of shape ``(1,)``
    (``dispersion="global"``) or ``(cfg.output_dim,)`` (``"gene"``).
    Any other ``dispersion`` raises ``ValueError``.
    """

    def __init__(
        self,
        cfg: DecoderConfig,
        latent_dim: int,
        dispersion: str = "gene",
        init_log_theta: float = 0.0,
        eps: float = 1e-8,
    ):
        super().__init__()
        self.cfg = cfg
        self.dispersion = dispersion
        self.eps = float(eps)

        # Twice the output width: mean logits followed by dropout logits.
        width = 2 * cfg.output_dim
        self.backbone = build_mlp(
            in_dim=latent_dim,
            hidden_dims=cfg.hidden_dims,
            out_dim=width,
            dropout=cfg.dropout,
            batchnorm=cfg.batchnorm,
        )

        if dispersion == "global":
            theta_shape = (1,)
        elif dispersion == "gene":
            theta_shape = (cfg.output_dim,)
        else:
            raise ValueError("Unknown dispersion mode: %r (expected 'global' or 'gene')" % dispersion)
        self.log_theta = nn.Parameter(torch.full(theta_shape, float(init_log_theta)))

    def forward(self, z: torch.Tensor) -> Dict[str, torch.Tensor]:
        mean_part, logit_pi = torch.chunk(self.backbone(z), 2, dim=-1)
        return {
            "mu": F.softplus(mean_part) + self.eps,
            "log_theta": self.log_theta,
            "logit_pi": logit_pi,
        }
|
|
173
|
+
|
|
174
|
+
|
|
175
|
+
class LogisticNormalDecoder(nn.Module):
    """Map ``z`` to a probability vector (compositions / toy data).

    Returns the raw ``"logits"`` and ``"probs"``, their softmax over the last
    dimension.

    NOTE(review): the computation is identical to ``CategoricalDecoder``; no
    Gaussian noise in logit space is involved here. Any logistic-normal
    behaviour must live in the loss. Confirm this is intended.
    """

    def __init__(self, cfg: DecoderConfig, latent_dim: int):
        super().__init__()
        self.cfg = cfg
        self.net = build_mlp(
            in_dim=latent_dim,
            out_dim=cfg.output_dim,
            hidden_dims=cfg.hidden_dims,
            dropout=cfg.dropout,
            batchnorm=cfg.batchnorm,
        )

    def forward(self, z: torch.Tensor) -> Dict[str, torch.Tensor]:
        scores = self.net(z)
        return {"logits": scores, "probs": torch.softmax(scores, dim=-1)}
|
|
193
|
+
|
|
194
|
+
|
|
195
|
+
class CategoricalDecoder(nn.Module):
    """Map ``z`` to class scores for discrete labels.

    Returns the raw ``"logits"`` (one per class, ``cfg.output_dim`` classes)
    and ``"probs"``, their softmax over the last dimension.
    """

    def __init__(self, cfg: DecoderConfig, latent_dim: int):
        super().__init__()
        self.cfg = cfg
        self.net = build_mlp(
            in_dim=latent_dim,
            out_dim=cfg.output_dim,
            hidden_dims=cfg.hidden_dims,
            dropout=cfg.dropout,
            batchnorm=cfg.batchnorm,
        )

    def forward(self, z: torch.Tensor) -> Dict[str, torch.Tensor]:
        class_logits = self.net(z)
        class_probs = torch.softmax(class_logits, dim=-1)
        return dict(logits=class_logits, probs=class_probs)
|
|
213
|
+
|
|
214
|
+
|
|
215
|
+
# Lower-case decoder name -> decoder class; several names alias one class.
# Built from an alias table; keys keep the order they are listed in.
DECODER_REGISTRY = {
    alias: decoder_cls
    for decoder_cls, aliases in (
        # gaussian
        (GaussianDecoder, ("gaussian", "normal")),
        (GaussianDiagDecoder, ("gaussian_diag",)),
        # bernoulli/poisson
        (BernoulliDecoder, ("bernoulli",)),
        (PoissonDecoder, ("poisson",)),
        # count models
        (NegativeBinomialDecoder, ("nb", "negative_binomial")),
        (ZeroInflatedNegativeBinomialDecoder, ("zinb", "zero_inflated_negative_binomial")),
        # compositions / discrete
        (LogisticNormalDecoder, ("logistic_normal",)),
        (CategoricalDecoder, ("categorical", "cat", "ce", "cross_entropy")),
    )
    for alias in aliases
}
|
|
239
|
+
|
|
240
|
+
|
|
241
|
+
def build_decoder(kind: str, cfg: DecoderConfig, latent_dim: int, **kwargs: Any) -> nn.Module:
    """
    Instantiate a decoder by name.

    Parameters
    ----------
    kind : str
        Decoder name or alias from ``DECODER_REGISTRY`` (e.g. ``"gaussian"``,
        ``"nb"``, ``"zinb"``). Matching ignores case and surrounding
        whitespace, and treats ``-`` and inner spaces as ``_``, so
        ``"Negative-Binomial"`` resolves to ``"negative_binomial"``.
    cfg : DecoderConfig
        Shared MLP configuration passed to the decoder.
    latent_dim : int
        Input (latent) dimensionality.
    **kwargs
        Extra constructor arguments forwarded to the decoder class, e.g.
        ``dispersion`` / ``init_log_theta`` / ``eps`` for the NB and ZINB
        decoders.

    Returns
    -------
    nn.Module
        The constructed decoder.

    Raises
    ------
    ValueError
        If ``kind`` does not name a registered decoder.
    """
    # Normalise user spelling to the registry's lower_snake_case keys.
    key = str(kind).strip().lower().replace("-", "_").replace(" ", "_")
    cls = DECODER_REGISTRY.get(key)
    if cls is None:
        raise ValueError(
            "Unknown decoder kind: %r. Available: %s" % (kind, list(DECODER_REGISTRY.keys()))
        )
    return cls(cfg=cfg, latent_dim=latent_dim, **kwargs)
|
|
249
|
+
|