univi 0.3.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
univi/matching.py ADDED
@@ -0,0 +1,394 @@
+ # univi/matching.py
+
+ import warnings
+ from typing import Dict, Optional
+
+ import numpy as np
+ from scipy.optimize import linear_sum_assignment
+ from sklearn.metrics import pairwise_distances
+ from sklearn.neighbors import NearestNeighbors
+
+
+ def _subsample_indices(n: int, max_cells: int, rng: np.random.Generator) -> np.ndarray:
+     """Subsample up to max_cells indices from range(n) without replacement."""
+     idx_full = np.arange(n)
+     if n <= max_cells:
+         return idx_full
+     return rng.choice(idx_full, size=max_cells, replace=False)
+
+
+ # ---------------------------------------------------------------------------
+ # 1. Basic bipartite matching (Hungarian) in a shared embedding
+ # ---------------------------------------------------------------------------
+
+ def bipartite_match_adata(
+     adata_A,
+     adata_B,
+     emb_key: str = "X_pca",
+     metric: str = "euclidean",
+     max_cells: int = 20000,
+     random_state: int = 0,
+ ):
+     """
+     Bipartite matching between cells in adata_A and adata_B based on a shared embedding.
+
+     This is the basic "Hungarian in latent space" matcher. It assumes that both
+     adata_A and adata_B have a *comparable* embedding in .obsm[emb_key], e.g.:
+
+     - both are in the same PCA space, or
+     - both are in a shared latent space (CCA, UniVI encoder, etc.)
+
+     Note that the Hungarian algorithm scales roughly cubically with the number
+     of cells, so large values of max_cells can be slow and memory-hungry.
+     """
+     rng = np.random.default_rng(random_state)
+
+     XA = np.asarray(adata_A.obsm[emb_key])
+     XB = np.asarray(adata_B.obsm[emb_key])
+
+     na, nb = XA.shape[0], XB.shape[0]
+     n = min(na, nb, max_cells)
+
+     idx_A = _subsample_indices(na, n, rng)
+     idx_B = _subsample_indices(nb, n, rng)
+
+     XA_sub = XA[idx_A]
+     XB_sub = XB[idx_B]
+
+     # Cost matrix between the two subsamples
+     D = pairwise_distances(XA_sub, XB_sub, metric=metric)
+
+     # Hungarian algorithm (min-cost assignment)
+     row_ind, col_ind = linear_sum_assignment(D)
+
+     # Map local (subsampled) positions back to indices in the full AnnData
+     matched_A = idx_A[row_ind]
+     matched_B = idx_B[col_ind]
+
+     return matched_A, matched_B
+
+
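A minimal usage sketch (not part of the file): the AnnData objects and embeddings below are synthetic stand-ins for datasets that already share a latent basis.

    import anndata as ad
    import numpy as np

    from univi.matching import bipartite_match_adata

    rng = np.random.default_rng(0)
    adata_A = ad.AnnData(X=rng.normal(size=(500, 50)).astype(np.float32))
    adata_B = ad.AnnData(X=rng.normal(size=(400, 50)).astype(np.float32))
    # Stand-in embeddings; in practice both would come from one shared PCA/latent fit
    adata_A.obsm["X_pca"] = rng.normal(size=(500, 10))
    adata_B.obsm["X_pca"] = rng.normal(size=(400, 10))

    idx_A, idx_B = bipartite_match_adata(adata_A, adata_B, emb_key="X_pca")
    # idx_A[i] in adata_A is paired one-to-one with idx_B[i] in adata_B
    pairs = list(zip(idx_A, idx_B))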
+ # ---------------------------------------------------------------------------
+ # 2. Stratified bipartite matching (per group / cell type)
+ # ---------------------------------------------------------------------------
+
+ def stratified_bipartite_match_adata(
+     adata_A,
+     adata_B,
+     group_key_A: str,
+     group_key_B: Optional[str] = None,
+     group_map: Optional[Dict] = None,
+     emb_key: str = "X_pca",
+     metric: str = "euclidean",
+     max_cells_per_group: int = 20000,
+     random_state: int = 0,
+     shuffle: bool = True,
+ ):
+     """
+     Per-group (e.g. per-cell-type) bipartite matching in a shared embedding.
+
+     This wraps `bipartite_match_adata` but runs it separately within each group,
+     then concatenates the matches.
+     """
+     rng = np.random.default_rng(random_state)
+
+     if group_key_B is None:
+         group_key_B = group_key_A
+
+     if group_key_A not in adata_A.obs.columns:
+         raise KeyError(f"{group_key_A!r} not found in adata_A.obs")
+     if group_key_B not in adata_B.obs.columns:
+         raise KeyError(f"{group_key_B!r} not found in adata_B.obs")
+
+     labels_A = adata_A.obs[group_key_A].astype(str).to_numpy()
+     labels_B = adata_B.obs[group_key_B].astype(str).to_numpy()
+
+     unique_A = np.unique(labels_A)
+
+     all_matched_A = []
+     all_matched_B = []
+     group_counts: Dict[str, int] = {}
+
+     for gA in unique_A:
+         # Determine which label in B this group should be matched to
+         if group_map is not None:
+             gB = group_map.get(gA, None)
+             if gB is None:
+                 # Group not mapped; skip it
+                 continue
+         else:
+             gB = gA
+
+         idx_A_g = np.where(labels_A == gA)[0]
+         idx_B_g = np.where(labels_B == gB)[0]
+
+         if (idx_A_g.size == 0) or (idx_B_g.size == 0):
+             continue
+
+         # Build small per-group views
+         adata_A_g = adata_A[idx_A_g]
+         adata_B_g = adata_B[idx_B_g]
+
+         # Number of pairs to request for this group
+         n_grp = min(idx_A_g.size, idx_B_g.size, max_cells_per_group)
+         if n_grp == 0:
+             continue
+
+         mA_local, mB_local = bipartite_match_adata(
+             adata_A_g,
+             adata_B_g,
+             emb_key=emb_key,
+             metric=metric,
+             max_cells=n_grp,
+             random_state=random_state,
+         )
+
+         if mA_local.size == 0:
+             continue
+
+         # Map local (per-group) positions back to indices in the full AnnData
+         mA = idx_A_g[mA_local]
+         mB = idx_B_g[mB_local]
+
+         all_matched_A.append(mA)
+         all_matched_B.append(mB)
+         group_counts[gA] = mA.size
+
+     if not all_matched_A:
+         raise RuntimeError("No stratified matches were found for any group.")
+
+     matched_A = np.concatenate(all_matched_A)
+     matched_B = np.concatenate(all_matched_B)
+
+     if shuffle:
+         perm = rng.permutation(matched_A.size)
+         matched_A = matched_A[perm]
+         matched_B = matched_B[perm]
+
+     return matched_A, matched_B, group_counts
+
+
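A hypothetical call (not shipped in the package), assuming adata_rna and adata_atac exist with slightly different label vocabularies; group_map translates A-side labels to B-side labels.

    from univi.matching import stratified_bipartite_match_adata

    group_map = {"B cell": "B", "T cell": "T", "NK cell": "NK"}  # A label -> B label
    mA, mB, counts = stratified_bipartite_match_adata(
        adata_rna,                  # assumed AnnData with .obs["cell_type"]
        adata_atac,                 # assumed AnnData with .obs["celltype"]
        group_key_A="cell_type",
        group_key_B="celltype",
        group_map=group_map,
        emb_key="X_pca",
    )
    # counts reports how many pairs each A-side group contributed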
+ # ---------------------------------------------------------------------------
+ # 3. Mutual Nearest Neighbor (MNN) anchors
+ # ---------------------------------------------------------------------------
+
+ def mnn_anchors_adata(
+     adata_A,
+     adata_B,
+     emb_key: str = "X_pca",
+     k: int = 20,
+     max_cells: int = 20000,
+     random_state: int = 0,
+ ):
+     """
+     Mutual Nearest Neighbor (MNN) anchors between adata_A and adata_B.
+
+     Returns parallel index arrays into adata_A and adata_B; a cell may appear
+     in more than one anchor pair.
+     """
+     rng = np.random.default_rng(random_state)
+
+     XA = np.asarray(adata_A.obsm[emb_key])
+     XB = np.asarray(adata_B.obsm[emb_key])
+
+     na, nb = XA.shape[0], XB.shape[0]
+
+     idx_A = _subsample_indices(na, max_cells, rng)
+     idx_B = _subsample_indices(nb, max_cells, rng)
+
+     XA_sub = XA[idx_A]
+     XB_sub = XB[idx_B]
+
+     k_A = min(k, idx_B.size)
+     k_B = min(k, idx_A.size)
+
+     # A -> B neighbors (distances are not needed, only the neighbor indices)
+     nn_B = NearestNeighbors(n_neighbors=k_A)
+     nn_B.fit(XB_sub)
+     ind_A2B = nn_B.kneighbors(XA_sub, return_distance=False)
+
+     # B -> A neighbors
+     nn_A = NearestNeighbors(n_neighbors=k_B)
+     nn_A.fit(XA_sub)
+     ind_B2A = nn_A.kneighbors(XB_sub, return_distance=False)
+
+     # Keep only mutual pairs: i and j must each appear in the other's k-NN list
+     anchors_A_list = []
+     anchors_B_list = []
+
+     # For quick membership tests: for each j in B, the set of its neighbors in A
+     neighbors_B2A = [set(ind_B2A[j]) for j in range(idx_B.size)]
+
+     for i in range(idx_A.size):
+         for j_local in ind_A2B[i]:
+             if i in neighbors_B2A[j_local]:
+                 anchors_A_list.append(idx_A[i])
+                 anchors_B_list.append(idx_B[j_local])
+
+     if not anchors_A_list:
+         warnings.warn("mnn_anchors_adata: no mutual nearest neighbors found.")
+         return np.array([], dtype=int), np.array([], dtype=int)
+
+     anchors_A = np.array(anchors_A_list, dtype=int)
+     anchors_B = np.array(anchors_B_list, dtype=int)
+
+     return anchors_A, anchors_B
+
+
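Because a cell can occur in several anchor pairs, downstream code may want one pair per A-side cell. A small sketch (adata_A and adata_B assumed to exist with shared embeddings):

    import numpy as np

    from univi.matching import mnn_anchors_adata

    aA, aB = mnn_anchors_adata(adata_A, adata_B, emb_key="X_pca", k=15)
    if aA.size:
        _, first = np.unique(aA, return_index=True)  # first pair per A-side cell
        aA, aB = aA[first], aB[first]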
+ # ---------------------------------------------------------------------------
+ # 4. Cluster / cell-type centroid matching (for building group maps)
+ # ---------------------------------------------------------------------------
+
+ def cluster_centroid_matching_adata(
+     adata_A,
+     adata_B,
+     group_key_A: str,
+     group_key_B: Optional[str] = None,
+     emb_key: str = "X_pca",
+     metric: str = "euclidean",
+ ):
+     """
+     Match cluster / cell-type centroids across datasets via the Hungarian
+     algorithm. If the datasets have different numbers of groups, only
+     min(n_groups_A, n_groups_B) pairs are returned.
+     """
+     if group_key_B is None:
+         group_key_B = group_key_A
+
+     if group_key_A not in adata_A.obs.columns:
+         raise KeyError(f"{group_key_A!r} not found in adata_A.obs")
+     if group_key_B not in adata_B.obs.columns:
+         raise KeyError(f"{group_key_B!r} not found in adata_B.obs")
+
+     labels_A = adata_A.obs[group_key_A].astype(str).to_numpy()
+     labels_B = adata_B.obs[group_key_B].astype(str).to_numpy()
+
+     XA = np.asarray(adata_A.obsm[emb_key])
+     XB = np.asarray(adata_B.obsm[emb_key])
+
+     groups_A = np.unique(labels_A)
+     groups_B = np.unique(labels_B)
+
+     # Per-group centroids in the shared embedding
+     centroids_A = np.vstack([XA[labels_A == g].mean(axis=0) for g in groups_A])
+     centroids_B = np.vstack([XB[labels_B == g].mean(axis=0) for g in groups_B])
+
+     # Min-cost assignment between centroids
+     D = pairwise_distances(centroids_A, centroids_B, metric=metric)
+     row_ind, col_ind = linear_sum_assignment(D)
+
+     group_map: Dict[str, str] = {}
+     for i, j in zip(row_ind, col_ind):
+         group_map[groups_A[i]] = groups_B[j]
+
+     return group_map
+
+
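A sketch of the intended pairing of sections 4 and 2 (assuming Leiden cluster labels in both objects): derive a group map from centroids, then match within the mapped groups.

    from univi.matching import (
        cluster_centroid_matching_adata,
        stratified_bipartite_match_adata,
    )

    group_map = cluster_centroid_matching_adata(
        adata_A, adata_B, group_key_A="leiden", emb_key="X_pca"
    )
    mA, mB, counts = stratified_bipartite_match_adata(
        adata_A, adata_B, group_key_A="leiden", group_map=group_map
    )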
+ # ---------------------------------------------------------------------------
+ # 5. Gromov–Wasserstein OT-based anchors (optional; requires POT)
+ # ---------------------------------------------------------------------------
+
+ def gw_ot_anchors_adata(
+     adata_A,
+     adata_B,
+     emb_key: str = "X_pca",
+     max_cells: int = 3000,
+     random_state: int = 0,
+     normalize_distances: bool = True,
+ ):
+     """
+     Geometry-aware anchors via Gromov–Wasserstein optimal transport.
+
+     GW compares only within-dataset distance matrices, so the two embeddings
+     do not need to share a common space. The default max_cells is small
+     because the coupling matrix is dense (n x n) and GW is costly to solve.
+     """
+     try:
+         import ot  # type: ignore
+     except ImportError as e:
+         raise ImportError(
+             "gw_ot_anchors_adata requires the 'pot' package. "
+             "Install with: pip install pot"
+         ) from e
+
+     rng = np.random.default_rng(random_state)
+
+     XA = np.asarray(adata_A.obsm[emb_key])
+     XB = np.asarray(adata_B.obsm[emb_key])
+
+     na, nb = XA.shape[0], XB.shape[0]
+     idx_A = _subsample_indices(na, max_cells, rng)
+     idx_B = _subsample_indices(nb, max_cells, rng)
+
+     XA_sub = XA[idx_A]
+     XB_sub = XB[idx_B]
+
+     # Within-dataset distance matrices
+     DA = pairwise_distances(XA_sub, XA_sub, metric="euclidean")
+     DB = pairwise_distances(XB_sub, XB_sub, metric="euclidean")
+
+     if normalize_distances:
+         if DA.max() > 0:
+             DA = DA / DA.max()
+         if DB.max() > 0:
+             DB = DB / DB.max()
+
+     # Uniform marginals
+     p = np.ones(DA.shape[0]) / DA.shape[0]
+     q = np.ones(DB.shape[0]) / DB.shape[0]
+
+     # Dense GW coupling between the two metric spaces
+     T = ot.gromov.gromov_wasserstein(
+         DA, DB, p, q, loss_fun="square_loss", verbose=False
+     )
+
+     # For each i in A, keep the j in B that receives the most coupling mass
+     j_best = np.argmax(T, axis=1)
+     anchors_A = idx_A.copy()
+     anchors_B = idx_B[j_best]
+
+     return anchors_A, anchors_B
+
+
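A hypothetical sketch (adata_rna and adata_atac assumed): since GW never compares the two embeddings directly, mismatched dimensionalities work once each embedding is written into a common .obsm slot.

    from univi.matching import gw_ot_anchors_adata

    # e.g. a 50-d PCA on one side and a 30-d LSI on the other
    adata_rna.obsm["X_shared"] = adata_rna.obsm["X_pca"]
    adata_atac.obsm["X_shared"] = adata_atac.obsm["X_lsi"]
    aA, aB = gw_ot_anchors_adata(
        adata_rna, adata_atac, emb_key="X_shared", max_cells=2000
    )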
+ # ---------------------------------------------------------------------------
+ # 6. Group latent statistics (for distribution-level alignment)
+ # ---------------------------------------------------------------------------
+
+ def group_latent_stats_adata(
+     adata,
+     group_key: str,
+     emb_key: str = "X_pca",
+ ):
+     """
+     Compute simple group-wise latent statistics (mean, covariance) to
+     support distribution-level alignment strategies.
+     """
+     if group_key not in adata.obs.columns:
+         raise KeyError(f"{group_key!r} not found in adata.obs")
+
+     labels = adata.obs[group_key].astype(str).to_numpy()
+     X = np.asarray(adata.obsm[emb_key])
+
+     groups = np.unique(labels)
+     stats = {}
+
+     for g in groups:
+         idx = np.where(labels == g)[0]
+         if idx.size == 0:
+             continue
+         Xg = X[idx]
+         mu = Xg.mean(axis=0)
+         # rowvar=False so that columns (latent dimensions) are the variables;
+         # note np.cov needs at least two observations for a finite estimate
+         cov = np.cov(Xg, rowvar=False)
+         stats[g] = {
+             "mean": mu,
+             "cov": cov,
+             "n": idx.size,
+         }
+
+     return stats
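A sketch of one distribution-level use (adata_A and adata_B with a shared cell_type column assumed): compare per-group latent means across datasets.

    import numpy as np

    from univi.matching import group_latent_stats_adata

    stats_A = group_latent_stats_adata(adata_A, group_key="cell_type")
    stats_B = group_latent_stats_adata(adata_B, group_key="cell_type")
    for g in sorted(set(stats_A) & set(stats_B)):
        shift = np.linalg.norm(stats_A[g]["mean"] - stats_B[g]["mean"])
        print(f"{g}: mean shift {shift:.3f} "
              f"(n_A={stats_A[g]['n']}, n_B={stats_B[g]['n']})")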
univi/models/__init__.py ADDED
@@ -0,0 +1,8 @@
+ # univi/models/__init__.py
+ from __future__ import annotations
+
+ from .univi import UniVIMultiModalVAE
+ from .transformer import TransformerEncoder
+ from .tokenizers import build_tokenizer
+
+ __all__ = ["UniVIMultiModalVAE", "TransformerEncoder", "build_tokenizer"]
univi/models/decoders.py ADDED
@@ -0,0 +1,249 @@
+ # univi/models/decoders.py
+
+ from __future__ import annotations
+
+ from dataclasses import dataclass
+ from typing import Any, Dict, List
+
+ import torch
+ from torch import nn
+ import torch.nn.functional as F
+
+ from .mlp import build_mlp
+
+
+ @dataclass
+ class DecoderConfig:
+     """Generic configuration for feed-forward decoders."""
+
+     output_dim: int
+     hidden_dims: List[int]
+     dropout: float = 0.0
+     batchnorm: bool = False
+
+
+ class GaussianDecoder(nn.Module):
+     """z -> mean reconstruction (use with MSE/Gaussian losses)."""
+
+     def __init__(self, cfg: DecoderConfig, latent_dim: int):
+         super().__init__()
+         self.cfg = cfg
+         self.net = build_mlp(
+             in_dim=latent_dim,
+             hidden_dims=cfg.hidden_dims,
+             out_dim=cfg.output_dim,
+             dropout=cfg.dropout,
+             batchnorm=cfg.batchnorm,
+         )
+
+     def forward(self, z: torch.Tensor) -> torch.Tensor:
+         return self.net(z)
+
+
+ class GaussianDiagDecoder(nn.Module):
+     """z -> {'mean', 'logvar'} for diagonal-Gaussian likelihoods."""
+
+     def __init__(self, cfg: DecoderConfig, latent_dim: int):
+         super().__init__()
+         self.cfg = cfg
+         self.backbone = build_mlp(
+             in_dim=latent_dim,
+             hidden_dims=cfg.hidden_dims,
+             out_dim=2 * cfg.output_dim,
+             dropout=cfg.dropout,
+             batchnorm=cfg.batchnorm,
+         )
+
+     def forward(self, z: torch.Tensor) -> Dict[str, torch.Tensor]:
+         out = self.backbone(z)
+         mean, logvar = out.chunk(2, dim=-1)
+         return {"mean": mean, "logvar": logvar}
+
+
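A sketch of how a training loop could consume this output (decoder, z, and x are assumed tensors/modules, not names from this file): the diagonal-Gaussian negative log-likelihood from mean and logvar.

    import math

    import torch

    out = decoder(z)                      # {"mean": ..., "logvar": ...}
    mean, logvar = out["mean"], out["logvar"]
    # NLL of x under N(mean, exp(logvar)), elementwise
    nll = 0.5 * (logvar + (x - mean) ** 2 / logvar.exp() + math.log(2 * math.pi))
    loss = nll.sum(dim=-1).mean()         # sum over features, mean over batch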
+ class BernoulliDecoder(nn.Module):
+     """z -> {'logits'} for Bernoulli likelihoods."""
+
+     def __init__(self, cfg: DecoderConfig, latent_dim: int):
+         super().__init__()
+         self.cfg = cfg
+         self.net = build_mlp(
+             in_dim=latent_dim,
+             hidden_dims=cfg.hidden_dims,
+             out_dim=cfg.output_dim,
+             dropout=cfg.dropout,
+             batchnorm=cfg.batchnorm,
+         )
+
+     def forward(self, z: torch.Tensor) -> Dict[str, torch.Tensor]:
+         logits = self.net(z)
+         return {"logits": logits}
+
+
+ class PoissonDecoder(nn.Module):
+     """z -> {'log_rate', 'rate'} for Poisson likelihoods."""
+
+     def __init__(self, cfg: DecoderConfig, latent_dim: int):
+         super().__init__()
+         self.cfg = cfg
+         self.net = build_mlp(
+             in_dim=latent_dim,
+             hidden_dims=cfg.hidden_dims,
+             out_dim=cfg.output_dim,
+             dropout=cfg.dropout,
+             batchnorm=cfg.batchnorm,
+         )
+
+     def forward(self, z: torch.Tensor) -> Dict[str, torch.Tensor]:
+         # softplus keeps the rate positive; log_rate is the log of that rate
+         # (the raw network output itself is unconstrained)
+         rate = F.softplus(self.net(z)) + 1e-8
+         return {"log_rate": torch.log(rate), "rate": rate}
+
+
+ class NegativeBinomialDecoder(nn.Module):
+     """z -> {'mu', 'log_theta'} (theta can be global or gene-wise)."""
+
+     def __init__(
+         self,
+         cfg: DecoderConfig,
+         latent_dim: int,
+         dispersion: str = "gene",
+         init_log_theta: float = 0.0,
+         eps: float = 1e-8,
+     ):
+         super().__init__()
+         self.cfg = cfg
+         self.dispersion = dispersion
+         self.eps = float(eps)
+
+         self.mu_net = build_mlp(
+             in_dim=latent_dim,
+             hidden_dims=cfg.hidden_dims,
+             out_dim=cfg.output_dim,
+             dropout=cfg.dropout,
+             batchnorm=cfg.batchnorm,
+         )
+
+         if dispersion == "global":
+             self.log_theta = nn.Parameter(torch.full((1,), float(init_log_theta)))
+         elif dispersion == "gene":
+             self.log_theta = nn.Parameter(torch.full((cfg.output_dim,), float(init_log_theta)))
+         else:
+             raise ValueError("Unknown dispersion mode: %r (expected 'global' or 'gene')" % dispersion)
+
+     def forward(self, z: torch.Tensor) -> Dict[str, torch.Tensor]:
+         mu = F.softplus(self.mu_net(z)) + self.eps
+         return {"mu": mu, "log_theta": self.log_theta}
+
+
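A sketch of the matching negative-binomial likelihood in the (mu, theta) parameterization this decoder emits, where theta is the inverse-dispersion (decoder, z, and x are assumed, with x a tensor of non-negative counts):

    import torch

    out = decoder(z)                      # {"mu": ..., "log_theta": ...}
    mu, theta = out["mu"], out["log_theta"].exp()
    # log NB(x | mu, theta) = lgamma(x+theta) - lgamma(theta) - lgamma(x+1)
    #   + theta*log(theta/(theta+mu)) + x*log(mu/(theta+mu))
    log_prob = (
        torch.lgamma(x + theta) - torch.lgamma(theta) - torch.lgamma(x + 1.0)
        + theta * (torch.log(theta) - torch.log(theta + mu))
        + x * (torch.log(mu) - torch.log(theta + mu))
    )
    loss = -log_prob.sum(dim=-1).mean()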
+ class ZeroInflatedNegativeBinomialDecoder(nn.Module):
+     """z -> {'mu', 'log_theta', 'logit_pi'}."""
+
+     def __init__(
+         self,
+         cfg: DecoderConfig,
+         latent_dim: int,
+         dispersion: str = "gene",
+         init_log_theta: float = 0.0,
+         eps: float = 1e-8,
+     ):
+         super().__init__()
+         self.cfg = cfg
+         self.dispersion = dispersion
+         self.eps = float(eps)
+
+         self.backbone = build_mlp(
+             in_dim=latent_dim,
+             hidden_dims=cfg.hidden_dims,
+             out_dim=2 * cfg.output_dim,
+             dropout=cfg.dropout,
+             batchnorm=cfg.batchnorm,
+         )
+
+         if dispersion == "global":
+             self.log_theta = nn.Parameter(torch.full((1,), float(init_log_theta)))
+         elif dispersion == "gene":
+             self.log_theta = nn.Parameter(torch.full((cfg.output_dim,), float(init_log_theta)))
+         else:
+             raise ValueError("Unknown dispersion mode: %r (expected 'global' or 'gene')" % dispersion)
+
+     def forward(self, z: torch.Tensor) -> Dict[str, torch.Tensor]:
+         out = self.backbone(z)
+         # First half parameterizes the NB mean, second half the dropout logits
+         mu_raw, logit_pi = out.chunk(2, dim=-1)
+         mu = F.softplus(mu_raw) + self.eps
+         return {"mu": mu, "log_theta": self.log_theta, "logit_pi": logit_pi}
+
+
+ class LogisticNormalDecoder(nn.Module):
+     """z -> {'logits', 'probs'} (softmax) for compositional / probability-vector targets."""
+
+     def __init__(self, cfg: DecoderConfig, latent_dim: int):
+         super().__init__()
+         self.cfg = cfg
+         self.net = build_mlp(
+             in_dim=latent_dim,
+             hidden_dims=cfg.hidden_dims,
+             out_dim=cfg.output_dim,
+             dropout=cfg.dropout,
+             batchnorm=cfg.batchnorm,
+         )
+
+     def forward(self, z: torch.Tensor) -> Dict[str, torch.Tensor]:
+         logits = self.net(z)
+         probs = F.softmax(logits, dim=-1)
+         return {"logits": logits, "probs": probs}
+
+
+ class CategoricalDecoder(nn.Module):
+     """z -> {'logits', 'probs'} for discrete labels."""
+
+     def __init__(self, cfg: DecoderConfig, latent_dim: int):
+         super().__init__()
+         self.cfg = cfg
+         self.net = build_mlp(
+             in_dim=latent_dim,
+             hidden_dims=cfg.hidden_dims,
+             out_dim=cfg.output_dim,
+             dropout=cfg.dropout,
+             batchnorm=cfg.batchnorm,
+         )
+
+     def forward(self, z: torch.Tensor) -> Dict[str, torch.Tensor]:
+         logits = self.net(z)
+         probs = F.softmax(logits, dim=-1)
+         return {"logits": logits, "probs": probs}
+
+
+ DECODER_REGISTRY = {
+     # gaussian
+     "gaussian": GaussianDecoder,
+     "normal": GaussianDecoder,
+     "gaussian_diag": GaussianDiagDecoder,
+
+     # bernoulli / poisson
+     "bernoulli": BernoulliDecoder,
+     "poisson": PoissonDecoder,
+
+     # count models
+     "nb": NegativeBinomialDecoder,
+     "negative_binomial": NegativeBinomialDecoder,
+     "zinb": ZeroInflatedNegativeBinomialDecoder,
+     "zero_inflated_negative_binomial": ZeroInflatedNegativeBinomialDecoder,
+
+     # compositions / discrete
+     "logistic_normal": LogisticNormalDecoder,
+     "categorical": CategoricalDecoder,
+     "cat": CategoricalDecoder,
+     "ce": CategoricalDecoder,
+     "cross_entropy": CategoricalDecoder,
+ }
+
+
+ def build_decoder(kind: str, cfg: DecoderConfig, latent_dim: int, **kwargs: Any) -> nn.Module:
+     key = str(kind).lower()
+     if key not in DECODER_REGISTRY:
+         raise ValueError(
+             "Unknown decoder kind: %r. Available: %s" % (kind, sorted(DECODER_REGISTRY))
+         )
+     cls = DECODER_REGISTRY[key]
+     return cls(cfg=cfg, latent_dim=latent_dim, **kwargs)
+
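A usage sketch of the registry (runnable as written, assuming the package's build_mlp backbone behaves as the code above suggests): build a gene-wise-dispersion NB decoder for 2,000 genes from a 32-d latent space.

    import torch

    from univi.models.decoders import DecoderConfig, build_decoder

    cfg = DecoderConfig(output_dim=2000, hidden_dims=[256, 512], dropout=0.1)
    decoder = build_decoder("nb", cfg, latent_dim=32, dispersion="gene")

    z = torch.randn(8, 32)    # batch of 8 latent codes
    out = decoder(z)          # {"mu": (8, 2000), "log_theta": (2000,)}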