univi 0.3.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- univi/__init__.py +120 -0
- univi/__main__.py +5 -0
- univi/cli.py +60 -0
- univi/config.py +340 -0
- univi/data.py +345 -0
- univi/diagnostics.py +130 -0
- univi/evaluation.py +632 -0
- univi/hyperparam_optimization/__init__.py +17 -0
- univi/hyperparam_optimization/common.py +339 -0
- univi/hyperparam_optimization/run_adt_hparam_search.py +109 -0
- univi/hyperparam_optimization/run_atac_hparam_search.py +109 -0
- univi/hyperparam_optimization/run_citeseq_hparam_search.py +137 -0
- univi/hyperparam_optimization/run_multiome_hparam_search.py +145 -0
- univi/hyperparam_optimization/run_rna_hparam_search.py +111 -0
- univi/hyperparam_optimization/run_teaseq_hparam_search.py +146 -0
- univi/interpretability.py +399 -0
- univi/matching.py +394 -0
- univi/models/__init__.py +8 -0
- univi/models/decoders.py +249 -0
- univi/models/encoders.py +848 -0
- univi/models/mlp.py +36 -0
- univi/models/tokenizers.py +376 -0
- univi/models/transformer.py +249 -0
- univi/models/univi.py +1284 -0
- univi/objectives.py +46 -0
- univi/pipeline.py +194 -0
- univi/plotting.py +126 -0
- univi/trainer.py +478 -0
- univi/utils/__init__.py +5 -0
- univi/utils/io.py +621 -0
- univi/utils/logging.py +16 -0
- univi/utils/seed.py +18 -0
- univi/utils/stats.py +23 -0
- univi/utils/torch_utils.py +23 -0
- univi-0.3.4.dist-info/METADATA +908 -0
- univi-0.3.4.dist-info/RECORD +40 -0
- univi-0.3.4.dist-info/WHEEL +5 -0
- univi-0.3.4.dist-info/entry_points.txt +2 -0
- univi-0.3.4.dist-info/licenses/LICENSE +21 -0
- univi-0.3.4.dist-info/top_level.txt +1 -0
univi/models/mlp.py
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
1
|
+
# univi/models/mlp.py
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
from typing import List, Optional
|
|
5
|
+
from torch import nn
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def build_mlp(
    in_dim: int,
    hidden_dims: List[int],
    out_dim: int,
    activation: Optional[nn.Module] = None,
    dropout: float = 0.0,
    batchnorm: bool = True,
) -> nn.Sequential:
    """
    Generic MLP builder: [Linear -> BN -> Act -> Dropout]* + final Linear.

    Parameters
    ----------
    in_dim:
        Input feature dimension.
    hidden_dims:
        Width of each hidden layer, in order. An empty list yields a single
        ``Linear(in_dim, out_dim)``.
    out_dim:
        Output feature dimension. The final Linear is not followed by an
        activation.
    activation:
        Activation placed after every hidden layer. Three forms are accepted:

        - an ``nn.Module`` instance: it is deep-copied for each layer, so its
          configuration (e.g. ``nn.LeakyReLU(0.2)``) is preserved and no
          module is shared between layers;
        - an ``nn.Module`` subclass (e.g. ``nn.GELU``): it is instantiated
          with no arguments for each layer;
        - ``None`` (the default) or anything else: ``nn.ReLU`` is used.
    dropout:
        Dropout probability after each activation. No Dropout layer is added
        when it is falsy or <= 0.
    batchnorm:
        Insert ``BatchNorm1d`` after each hidden Linear.

    Returns
    -------
    nn.Sequential
        The assembled network.

    (Python gotcha: don't use nn.ReLU() as a default arg; it becomes a shared instance.)
    """
    import copy

    def _fresh_activation() -> nn.Module:
        # Build a new module for every hidden layer so no instance is shared.
        if isinstance(activation, nn.Module):
            # deepcopy rather than ``activation.__class__()``: the latter drops
            # constructor arguments (negative_slope, approximate, ...) and fails
            # for activations whose constructor has required parameters.
            return copy.deepcopy(activation)
        if isinstance(activation, type) and issubclass(activation, nn.Module):
            return activation()
        return nn.ReLU()

    layers: List[nn.Module] = []
    last_dim = in_dim
    for h in hidden_dims:
        layers.append(nn.Linear(last_dim, h))
        if batchnorm:
            layers.append(nn.BatchNorm1d(h))
        layers.append(_fresh_activation())
        if dropout and dropout > 0:
            layers.append(nn.Dropout(float(dropout)))
        last_dim = h

    layers.append(nn.Linear(last_dim, out_dim))
    return nn.Sequential(*layers)
|
|
36
|
+
|
|
@@ -0,0 +1,376 @@
|
|
|
1
|
+
# univi/models/tokenizers.py
|
|
2
|
+
from __future__ import annotations
|
|
3
|
+
|
|
4
|
+
from typing import Optional, Tuple, Sequence, Literal, Dict, Any
|
|
5
|
+
|
|
6
|
+
import torch
|
|
7
|
+
from torch import nn
|
|
8
|
+
|
|
9
|
+
from ..config import TokenizerConfig
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class Tokenizer(nn.Module):
    """
    Abstract base for modules that turn a feature matrix into a token sequence.

    Subclasses implement ``d_in`` (per-token width) and ``forward``.

    Backwards-compatible call:
        forward(x) -> (tokens, key_padding_mask)

    Extras:
      - ``self.last_meta`` holds whatever metadata the last forward() stored
        (subclasses reassign it inside forward()).
      - forward_with_meta(x) -> (tokens, key_padding_mask, meta), where meta
        is a shallow copy of ``last_meta``.

    Conventions
    -----------
    - tokens: (B, T, D_in)
    - key_padding_mask: Optional[(B, T)] where True means "PAD / ignore"
    - meta: dict (optional), e.g. {"token_pos": (B, T) basepair positions}
    """

    def __init__(self):
        super().__init__()
        # Metadata side-channel; empty until a subclass forward() fills it.
        self.last_meta: Dict[str, Any] = {}

    @property
    def d_in(self) -> int:
        """Width of each emitted token (the last dimension of ``tokens``)."""
        raise NotImplementedError

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Tokenize ``x``; must be provided by subclasses."""
        raise NotImplementedError

    def forward_with_meta(self, x: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor], Dict[str, Any]]:
        """Call forward() and return its outputs plus a copy of the recorded metadata."""
        tokens, key_padding_mask = self.forward(x)
        meta_snapshot = dict(self.last_meta)
        return tokens, key_padding_mask, meta_snapshot
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
class TopKScalarTokenizer(Tokenizer):
    """
    Keep each cell's K largest-magnitude features as scalar tokens.

    (B, F) -> (B, K, 1) with K = min(n_tokens, F). Tokens are ordered by
    decreasing |x| and carry the signed value. With ``add_cls_token`` a single
    all-zero token is prepended, giving (B, K + 1, 1).

    ``last_meta["feature_idx"]`` is the (B, K) column index of each non-CLS
    token. No padding mask is produced.
    """

    def __init__(self, n_tokens: int, add_cls_token: bool = False):
        super().__init__()
        self.n_tokens = int(n_tokens)
        self.add_cls_token = bool(add_cls_token)

    @property
    def d_in(self) -> int:
        return 1

    def forward(self, x: torch.Tensor):
        batch, n_feat = x.shape
        k = min(self.n_tokens, n_feat)

        # Indices of the k largest |x| per row, highest first.
        top_idx = x.abs().topk(k, dim=1, largest=True, sorted=True).indices
        tokens = x.gather(1, top_idx)[..., None]  # (B, k, 1)

        self.last_meta = {"feature_idx": top_idx}

        if self.add_cls_token:
            cls_tok = x.new_zeros((batch, 1, 1))
            tokens = torch.cat((cls_tok, tokens), dim=1)

        return tokens, None
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
class TopKChannelsTokenizer(Tokenizer):
    """
    Top-k tokenizer whose tokens are small feature vectors.

    (B, F) -> (B, K, C), K = min(n_tokens, F), C = len(channels). The K
    columns with the largest |x| are kept, ordered by decreasing |x|. Each
    token stacks the requested channels in the order given:

    - value: the raw selected x_i
    - rank: position within the selected K, scaled to [0, 1]
    - dropout: 1.0 where the selected x_i == 0, else 0.0

    With ``add_cls_token`` an all-zero token is prepended, giving (B, K + 1, C).
    ``last_meta["feature_idx"]`` holds the (B, K) selected column indices.
    """

    def __init__(
        self,
        n_tokens: int,
        channels: Sequence[Literal["value", "rank", "dropout"]] = ("value", "rank", "dropout"),
        add_cls_token: bool = False,
    ):
        super().__init__()
        self.n_tokens = int(n_tokens)
        self.channels = tuple(channels)
        self.add_cls_token = bool(add_cls_token)

        if not self.channels:
            raise ValueError("TopKChannelsTokenizer requires at least one channel.")
        unknown = [name for name in self.channels if name not in ("value", "rank", "dropout")]
        if unknown:
            raise ValueError(f"Unknown channel {unknown[0]!r}. Allowed: value, rank, dropout")

    @property
    def d_in(self) -> int:
        return len(self.channels)

    def forward(self, x: torch.Tensor):
        batch, n_feat = x.shape
        k = min(self.n_tokens, n_feat)

        top_idx = x.abs().topk(k, dim=1, largest=True, sorted=True).indices
        picked = x.gather(1, top_idx)  # (B, k)

        def channel(name: str) -> torch.Tensor:
            if name == "value":
                return picked
            if name == "dropout":
                return (picked == 0).to(picked.dtype)
            if name == "rank":
                ranks = torch.arange(k, device=x.device, dtype=picked.dtype)
                return ranks.view(1, k).expand(batch, k) / max(k - 1, 1)
            raise RuntimeError("unreachable")

        tokens = torch.stack([channel(name) for name in self.channels], dim=-1)  # (B, k, C)
        self.last_meta = {"feature_idx": top_idx}

        if self.add_cls_token:
            cls_tok = x.new_zeros((batch, 1, tokens.size(-1)))
            tokens = torch.cat((cls_tok, tokens), dim=1)

        return tokens, None
|
|
130
|
+
|
|
131
|
+
|
|
132
|
+
class PatchTokenizer(Tokenizer):
    """
    Cut the feature vector into fixed-width, contiguous patches.

        (B, F) -> (B, T, patch_size),  T = ceil(F / patch_size)

    The last patch is right-padded with zeros when F is not a multiple of
    patch_size. If ``patch_proj_dim`` is given, every patch is mapped
    LayerNorm -> Linear -> GELU -> Linear to width ``patch_proj_dim``.
    With ``add_cls_token`` an all-zero token is prepended. No padding mask is
    produced and ``last_meta`` is reset to {}.
    """

    def __init__(
        self,
        patch_size: int,
        add_cls_token: bool = False,
        patch_proj_dim: Optional[int] = None,
    ):
        super().__init__()
        self.patch_size = int(patch_size)
        self.add_cls_token = bool(add_cls_token)
        self.patch_proj_dim = None if patch_proj_dim is None else int(patch_proj_dim)

        if self.patch_size <= 0:
            raise ValueError("patch_size must be > 0")

        proj_dim = self.patch_proj_dim
        self.proj = (
            nn.Sequential(
                nn.LayerNorm(self.patch_size),
                nn.Linear(self.patch_size, proj_dim),
                nn.GELU(),
                nn.Linear(proj_dim, proj_dim),
            )
            if proj_dim is not None
            else None
        )

    @property
    def d_in(self) -> int:
        if self.patch_proj_dim is None:
            return self.patch_size
        return self.patch_proj_dim

    def forward(self, x: torch.Tensor):
        batch, n_feat = x.shape
        width = self.patch_size
        n_patches = -(-n_feat // width)  # ceil division
        shortfall = n_patches * width - n_feat

        padded = x
        if shortfall:
            padded = torch.cat((x, x.new_zeros((batch, shortfall))), dim=1)

        patches = padded.view(batch, n_patches, width)  # (B, T, P)
        if self.proj is not None:
            patches = self.proj(patches)  # (B, T, patch_proj_dim)

        if self.add_cls_token:
            cls_tok = x.new_zeros((batch, 1, patches.size(-1)))
            patches = torch.cat((cls_tok, patches), dim=1)

        self.last_meta = {}
        return patches, None
|
|
192
|
+
|
|
193
|
+
|
|
194
|
+
class TopKEmbeddedTokenizer(Tokenizer):
    """
    Top-k tokenizer with explicit feature identity embeddings:

        token = Emb(feature_id) + MLP(channels(value/rank/dropout))

    Optional ATAC coordinate embeddings:

        token += Emb(chrom_id) + MLP(midpoint_bp / coord_scale)

    Output
    ------
    forward(x) expects x of shape (B, n_features) and returns (tokens, None).
    tokens has shape (B, K, d_model) with K = min(n_tokens, n_features),
    ordered by decreasing |x|. With add_cls_token=True a leading all-zero
    token is prepended, giving (B, K + 1, d_model).

    Meta
    ----
    self.last_meta will include:
      - "feature_idx": (B,K) long column indices (never includes the CLS slot)
      - "token_pos": float basepair midpoints (if use_coords=True), shape
        (B,K), or (B,K+1) with a leading 0 when add_cls_token=True

    Raises
    ------
    ValueError
        For empty or unknown channels. With use_coords=True, also for a
        non-positive chrom_vocab_size, a missing feature_info or missing keys,
        arrays whose length differs from n_features, or chromosome ids
        outside [0, chrom_vocab_size).
    """

    def __init__(
        self,
        *,
        n_tokens: int,
        n_features: int,
        d_model: int,
        channels: Sequence[Literal["value", "rank", "dropout"]] = ("value", "rank", "dropout"),
        add_cls_token: bool = False,
        value_mlp_hidden: int = 256,
        # coordinate extras
        use_coords: bool = False,
        chrom_vocab_size: int = 0,
        feature_info: Optional[Dict[str, Any]] = None,
        coord_scale: float = 1e6,
    ):
        super().__init__()
        self.n_tokens = int(n_tokens)
        self.n_features = int(n_features)
        self._d_model = int(d_model)
        self.channels = tuple(channels)
        self.add_cls_token = bool(add_cls_token)
        self.use_coords = bool(use_coords)
        self.chrom_vocab_size = int(chrom_vocab_size)
        self.coord_scale = float(coord_scale)

        if len(self.channels) == 0:
            raise ValueError("TopKEmbeddedTokenizer requires at least one channel.")
        for c in self.channels:
            if c not in ("value", "rank", "dropout"):
                raise ValueError(f"Unknown channel {c!r}. Allowed: value, rank, dropout")

        self.id_embed = nn.Embedding(self.n_features, self._d_model)

        c_in = len(self.channels)
        # NOTE(review): LayerNorm over the channel axis is degenerate for small
        # c_in. With c_in == 1 its output is always the learned bias, so the
        # channel value never reaches the model. With c_in == 2 only the sign
        # of the difference survives. Left unchanged to stay state_dict
        # compatible with existing checkpoints; consider nn.Identity for c_in == 1.
        self.val_proj = nn.Sequential(
            nn.LayerNorm(c_in),
            nn.Linear(c_in, int(value_mlp_hidden)),
            nn.GELU(),
            nn.Linear(int(value_mlp_hidden), self._d_model),
        )

        # Per-feature coordinate tables. They are registered as (empty)
        # non-persistent buffers up front so the attributes always exist.
        # Assigning a plain ``None`` attribute first would make the later
        # register_buffer() calls raise KeyError("attribute ... already exists").
        self.register_buffer("feature_chrom", None, persistent=False)
        self.register_buffer("feature_start", None, persistent=False)
        self.register_buffer("feature_end", None, persistent=False)

        if self.use_coords:
            if self.chrom_vocab_size <= 0:
                raise ValueError("chrom_vocab_size must be > 0 when use_coords=True.")
            if feature_info is None:
                raise ValueError("feature_info must be provided when use_coords=True (keys: chrom,start,end).")
            for k in ("chrom", "start", "end"):
                if k not in feature_info:
                    raise ValueError(f"feature_info missing key {k!r} (required for coords).")

            # Flatten so that indexing with (B,K) feature indices yields (B,K).
            chrom = torch.as_tensor(feature_info["chrom"], dtype=torch.long).reshape(-1)
            start = torch.as_tensor(feature_info["start"], dtype=torch.float32).reshape(-1)
            end = torch.as_tensor(feature_info["end"], dtype=torch.float32).reshape(-1)

            if chrom.numel() != self.n_features or start.numel() != self.n_features or end.numel() != self.n_features:
                raise ValueError(
                    f"feature_info arrays must have length n_features={self.n_features}; "
                    f"got chrom={chrom.numel()}, start={start.numel()}, end={end.numel()}."
                )

            # Catch out-of-vocabulary chromosome ids here, instead of as an
            # opaque index error (or CUDA device-side assert) in chrom_embed.
            if chrom.numel() > 0:
                lo, hi = int(chrom.min()), int(chrom.max())
                if lo < 0 or hi >= self.chrom_vocab_size:
                    raise ValueError(
                        f"feature_info['chrom'] ids must lie in [0, chrom_vocab_size={self.chrom_vocab_size}); "
                        f"got range [{lo}, {hi}]."
                    )

            # register buffers so they follow .to(device)
            self.register_buffer("feature_chrom", chrom, persistent=False)
            self.register_buffer("feature_start", start, persistent=False)
            self.register_buffer("feature_end", end, persistent=False)

            self.chrom_embed = nn.Embedding(self.chrom_vocab_size, self._d_model)
            # The input is a single scalar per token. nn.LayerNorm(1) would map
            # every value to its bias (x - mean == 0), discarding the coordinate
            # entirely, so slot 0 is an Identity instead. The index layout is
            # kept so the Linear layers remain coord_mlp.1 / coord_mlp.3.
            self.coord_mlp = nn.Sequential(
                nn.Identity(),
                nn.Linear(1, int(value_mlp_hidden)),
                nn.GELU(),
                nn.Linear(int(value_mlp_hidden), self._d_model),
            )

    @property
    def d_in(self) -> int:
        return self._d_model

    def forward(self, x: torch.Tensor):
        B, F = x.shape
        if F != self.n_features:
            raise ValueError(f"Expected F={self.n_features}, got {F}. Did you set TokenizerConfig.n_features correctly?")

        K = min(self.n_tokens, F)

        # K largest-|x| features per cell, ordered by decreasing |x|.
        _, idx = torch.topk(x.abs(), k=K, dim=1, largest=True, sorted=True)  # (B,K)
        vals = torch.gather(x, 1, idx)  # (B,K)

        # channels -> (B,K,C)
        chans = []
        for c in self.channels:
            if c == "value":
                chans.append(vals)
            elif c == "dropout":
                chans.append((vals == 0).to(vals.dtype))
            elif c == "rank":
                r = torch.arange(K, device=x.device, dtype=vals.dtype).view(1, K).expand(B, K)
                chans.append(r / max(K - 1, 1))
            else:
                raise RuntimeError("unreachable")

        ch = torch.stack(chans, dim=-1)  # (B,K,C)

        id_emb = self.id_embed(idx)  # (B,K,D)
        val_emb = self.val_proj(ch)  # (B,K,D)
        tokens = id_emb + val_emb

        meta: Dict[str, Any] = {"feature_idx": idx}

        if self.use_coords:
            # buffers are populated in __init__ whenever use_coords=True
            chrom = self.feature_chrom[idx]  # (B,K)
            mid = 0.5 * (self.feature_start[idx] + self.feature_end[idx])  # (B,K)
            mid_scaled = (mid / self.coord_scale).unsqueeze(-1)  # (B,K,1)

            tokens = tokens + self.chrom_embed(chrom) + self.coord_mlp(mid_scaled)
            meta["token_pos"] = mid  # basepairs

        if self.add_cls_token:
            # Match the tokens' dtype/device (they may differ from x, e.g. under autocast).
            cls = tokens.new_zeros((B, 1, tokens.size(-1)))
            tokens = torch.cat([cls, tokens], dim=1)
            # keep token_pos aligned with the token axis (feature_idx stays (B,K))
            if "token_pos" in meta:
                cls_pos = meta["token_pos"].new_zeros((B, 1))
                meta["token_pos"] = torch.cat([cls_pos, meta["token_pos"]], dim=1)

        self.last_meta = meta
        return tokens, None
|
|
341
|
+
|
|
342
|
+
|
|
343
|
+
def build_tokenizer(cfg: TokenizerConfig) -> Tokenizer:
    """
    Instantiate the tokenizer selected by ``cfg.mode``.

    The mode is matched after lower-casing and stripping whitespace (``None``
    counts as ""). Recognised modes:

    - "topk_scalar"   -> TopKScalarTokenizer
    - "topk_channels" -> TopKChannelsTokenizer
    - "patch"         -> PatchTokenizer
    - "topk_embed"    -> TopKEmbeddedTokenizer (needs cfg.n_features and cfg.d_model)

    Only the config fields used by the selected mode are read.

    Raises
    ------
    ValueError
        For an unknown mode, or "topk_embed" without n_features / d_model.
    """

    def _topk_scalar() -> Tokenizer:
        return TopKScalarTokenizer(n_tokens=cfg.n_tokens, add_cls_token=cfg.add_cls_token)

    def _topk_channels() -> Tokenizer:
        return TopKChannelsTokenizer(
            n_tokens=cfg.n_tokens,
            channels=cfg.channels,
            add_cls_token=cfg.add_cls_token,
        )

    def _patch() -> Tokenizer:
        return PatchTokenizer(
            patch_size=cfg.patch_size,
            add_cls_token=cfg.add_cls_token,
            patch_proj_dim=cfg.patch_proj_dim,
        )

    def _topk_embed() -> Tokenizer:
        if cfg.n_features is None or cfg.d_model is None:
            raise ValueError("TokenizerConfig.mode='topk_embed' requires n_features and d_model to be set.")
        return TopKEmbeddedTokenizer(
            n_tokens=cfg.n_tokens,
            n_features=int(cfg.n_features),
            d_model=int(cfg.d_model),
            channels=cfg.channels,
            add_cls_token=cfg.add_cls_token,
            value_mlp_hidden=int(cfg.value_mlp_hidden),
            use_coords=bool(cfg.use_coords),
            chrom_vocab_size=int(cfg.chrom_vocab_size),
            feature_info=cfg.feature_info,
            coord_scale=float(cfg.coord_scale),
        )

    factories = {
        "topk_scalar": _topk_scalar,
        "topk_channels": _topk_channels,
        "patch": _patch,
        "topk_embed": _topk_embed,
    }

    key = (cfg.mode or "").lower().strip()
    factory = factories.get(key)
    if factory is None:
        raise ValueError(f"Unknown tokenizer mode {cfg.mode!r}")
    return factory()
|
|
376
|
+
|
|
@@ -0,0 +1,249 @@
|
|
|
1
|
+
# univi/models/transformer.py
|
|
2
|
+
from __future__ import annotations
|
|
3
|
+
|
|
4
|
+
from dataclasses import dataclass
|
|
5
|
+
from typing import Optional, Literal, List, Tuple, Union
|
|
6
|
+
|
|
7
|
+
import torch
|
|
8
|
+
from torch import nn
|
|
9
|
+
import torch.nn.functional as F
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
@dataclass
class TransformerConfig:
    """
    Hyperparameters consumed by TransformerBlock and TransformerEncoder.

    Field order defines the positional constructor signature; keep it stable.
    """

    # -- core shape (required) --
    d_model: int  # token width inside the transformer
    num_heads: int  # attention heads per block
    num_layers: int  # number of stacked blocks

    # -- feed-forward / regularisation --
    dim_feedforward: int = 4_096  # hidden width of each block's FFN
    dropout: float = 0.1  # residual, FFN and input-embedding dropout
    attn_dropout: float = 0.1  # dropout inside multi-head attention
    activation: Literal["relu", "gelu"] = "gelu"

    # -- readout --
    pooling: Literal["cls", "mean"] = "mean"
    max_tokens: Optional[int] = None  # length of the learned positional table

    # -- optional binned relative-position attention bias (e.g. genomic distance) --
    use_relpos_bias: bool = False
    relpos_num_bins: int = 32
    relpos_max_dist: float = 1_000_000.0  # basepairs
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def _act(name: str):
|
|
31
|
+
name = str(name).lower().strip()
|
|
32
|
+
if name == "relu":
|
|
33
|
+
return F.relu
|
|
34
|
+
if name == "gelu":
|
|
35
|
+
return F.gelu
|
|
36
|
+
raise ValueError(f"Unknown activation: {name!r}")
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
class GenomicRelPosBias(nn.Module):
    """
    Simple distance-binned relative attention bias.

    Given token positions pos (B,T) in basepairs, returns an additive bias
    (B, H, T, T). Intended for ATAC peak midpoints.

    Notes
    -----
    - Pairwise |pos_i - pos_j| is clamped to [0, max_dist] and log1p-compressed,
      then mapped linearly onto bins 0..num_bins-1. This allocates more bins
      to shorter distances.
    - Bias table is learned: (H, num_bins), zero-initialised.
    - Positions and distances are handled in at least float32. Basepair
      coordinates overflow float16 (max 65504), where max_dist itself would
      become inf and every distance would collapse into bin 0.

    Raises
    ------
    ValueError
        If num_heads < 1, num_bins < 1, or max_dist is not a finite positive number.
    """

    def __init__(self, num_heads: int, num_bins: int = 32, max_dist: float = 1e6):
        super().__init__()
        self.num_heads = int(num_heads)
        self.num_bins = int(num_bins)
        self.max_dist = float(max_dist)

        if self.num_heads < 1:
            raise ValueError(f"num_heads must be >= 1, got {self.num_heads}")
        if self.num_bins < 1:
            raise ValueError(f"num_bins must be >= 1, got {self.num_bins}")
        # max_dist <= 0 would divide by log1p(max_dist) == 0 in _bin; NaN/inf are rejected too.
        if not (0.0 < self.max_dist < float("inf")):
            raise ValueError(f"max_dist must be a finite positive number, got {self.max_dist}")

        self.bias = nn.Parameter(torch.zeros(self.num_heads, self.num_bins))

    def _bin(self, dist: torch.Tensor) -> torch.Tensor:
        """Map non-negative distances (B,T,T) to long bin indices in [0, num_bins-1]."""
        if dist.dtype not in (torch.float32, torch.float64):
            dist = dist.to(torch.float32)
        d = torch.log1p(dist.clamp(min=0.0, max=self.max_dist))
        dmax = torch.log1p(torch.tensor(self.max_dist, device=d.device, dtype=d.dtype))
        b = (d / dmax) * (self.num_bins - 1)
        # clamp guards against rounding pushing the largest distance past the last bin
        return b.to(torch.long).clamp_(0, self.num_bins - 1)

    def forward(self, pos: torch.Tensor) -> torch.Tensor:
        """
        Parameters
        ----------
        pos : torch.Tensor
            (B, T) token coordinates in basepairs. Integer or half-precision
            inputs are promoted to float32 before differencing.

        Returns
        -------
        torch.Tensor
            (B, H, T, T) additive attention bias gathered from ``self.bias``.
        """
        if pos.dtype not in (torch.float32, torch.float64):
            pos = pos.to(torch.float32)
        dist = (pos[:, :, None] - pos[:, None, :]).abs()  # (B,T,T)
        bins = self._bin(dist)  # (B,T,T)
        # bias[:, bins] -> (H,B,T,T) then permute -> (B,H,T,T)
        out = self.bias[:, bins]
        return out.permute(1, 0, 2, 3).contiguous()
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
class TransformerBlock(nn.Module):
    """
    Single post-norm transformer block (LayerNorm applied after each residual add):

        x = LN1(x + Dropout(MHA(x, x, x)))
        x = LN2(x + Dropout(FFN(x)))

    The FFN is Linear(d_model, dim_feedforward) -> GELU (when cfg.activation
    is "gelu", otherwise ReLU) -> Dropout -> Linear(dim_feedforward, d_model).

    Supports an optional additive attention bias (distance-binned relative
    position via GenomicRelPosBias). It is applied when cfg.use_relpos_bias
    is set and ``token_pos`` is passed to forward().
    """

    def __init__(self, cfg: TransformerConfig):
        super().__init__()
        self.cfg = cfg
        d_model = int(cfg.d_model)
        self.num_heads = int(cfg.num_heads)

        self.attn = nn.MultiheadAttention(
            embed_dim=d_model,
            num_heads=self.num_heads,
            dropout=float(cfg.attn_dropout),
            batch_first=True,
        )
        self.attn_drop = nn.Dropout(float(cfg.dropout))
        self.ln1 = nn.LayerNorm(d_model)

        self.ff = nn.Sequential(
            nn.Linear(d_model, int(cfg.dim_feedforward)),
            # NOTE(review): any activation other than "gelu" falls back to ReLU silently.
            nn.GELU() if str(cfg.activation).lower().strip() == "gelu" else nn.ReLU(),
            nn.Dropout(float(cfg.dropout)),
            nn.Linear(int(cfg.dim_feedforward), d_model),
        )
        self.ff_drop = nn.Dropout(float(cfg.dropout))
        self.ln2 = nn.LayerNorm(d_model)

        self.relpos: Optional[GenomicRelPosBias] = None
        if bool(getattr(cfg, "use_relpos_bias", False)):
            self.relpos = GenomicRelPosBias(
                num_heads=self.num_heads,
                num_bins=int(getattr(cfg, "relpos_num_bins", 32)),
                max_dist=float(getattr(cfg, "relpos_max_dist", 1e6)),
            )

    def forward(
        self,
        x: torch.Tensor,
        *,
        key_padding_mask: Optional[torch.Tensor] = None,
        token_pos: Optional[torch.Tensor] = None,  # (B,T) basepairs or other coordinates
        return_attn: bool = False,
        attn_average_heads: bool = True,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """
        Parameters
        ----------
        x : (B, T, d_model) input tokens.
        key_padding_mask : optional (B, T) bool, True = PAD / ignore.
        token_pos : optional (B, T) coordinates for the relative-position bias.
            Ignored unless the block was built with use_relpos_bias.
        return_attn : also return the attention weights.
        attn_average_heads : forwarded to nn.MultiheadAttention as
            ``average_attn_weights``.

        Returns
        -------
        x, or (x, attn_w) when return_attn is True. Per nn.MultiheadAttention,
        attn_w is (B, T, T) when averaged over heads, else (B, H, T, T).
        """
        need_weights = bool(return_attn)

        attn_mask = None
        if self.relpos is not None and token_pos is not None:
            # (B,H,T,T) -> (B*H,T,T): the 3-D float attn_mask layout
            # nn.MultiheadAttention expects, added to the attention logits.
            bias = self.relpos(token_pos).to(dtype=x.dtype)
            B, H, T, _ = bias.shape
            attn_mask = bias.view(B * H, T, T)

        attn_out, attn_w = self.attn(
            x, x, x,
            key_padding_mask=key_padding_mask,  # (B, T) True = PAD
            attn_mask=attn_mask,  # None or (B*H,T,T)
            need_weights=need_weights,
            average_attn_weights=bool(attn_average_heads),
        )

        # Post-norm residual updates.
        x = self.ln1(x + self.attn_drop(attn_out))
        ff_out = self.ff(x)
        x = self.ln2(x + self.ff_drop(ff_out))

        if return_attn:
            if attn_w is None:
                raise RuntimeError("Expected attn_w when return_attn=True, got None.")
            return x, attn_w
        return x
|
|
150
|
+
|
|
151
|
+
|
|
152
|
+
class TransformerEncoder(nn.Module):
    """
    Sequence-to-vector transformer encoder.

    Pipeline::

        tokens (B,T,D_in) -> input_proj -> [+ pos_emb] -> dropout
            -> TransformerBlock x num_layers -> pool -> out_proj -> (B,d_out)

    ``input_proj`` is an identity when D_in already equals cfg.d_model.

    Options
    -------
    - use_positional_encoding: add learned absolute position embeddings
      (needs cfg.max_tokens; longer sequences raise ValueError).
    - cfg.use_relpos_bias: each block adds a relative attention bias from the
      ``token_pos`` passed to forward().
    - cfg.pooling: "cls" reads token 0; "mean" averages tokens, skipping
      positions whose key_padding_mask entry is True.
    """

    def __init__(
        self,
        *,
        cfg: TransformerConfig,
        d_in: int,
        d_out: int,
        use_positional_encoding: bool = True,
    ):
        super().__init__()
        self.cfg = cfg
        self.use_positional_encoding = bool(use_positional_encoding)

        width = int(cfg.d_model)
        in_width = int(d_in)
        # Submodules are created in a fixed order so seeded initialisation is reproducible.
        self.input_proj = nn.Linear(in_width, width, bias=True) if in_width != width else nn.Identity()
        self.blocks = nn.ModuleList(TransformerBlock(cfg) for _ in range(int(cfg.num_layers)))
        self.dropout = nn.Dropout(float(cfg.dropout))
        self.out_proj = nn.Linear(width, int(d_out), bias=True)

        self.pooling = str(cfg.pooling).lower().strip()
        if self.pooling not in {"cls", "mean"}:
            raise ValueError(f"Unknown pooling={cfg.pooling!r}")

        # Learned absolute positions, one row per token slot (optional).
        self.pos_emb: Optional[nn.Parameter] = None
        if self.use_positional_encoding:
            if cfg.max_tokens is None:
                raise ValueError("use_positional_encoding=True requires cfg.max_tokens to be set.")
            table = torch.zeros(1, int(cfg.max_tokens), width)
            self.pos_emb = nn.Parameter(table)
            nn.init.normal_(self.pos_emb, mean=0.0, std=0.02)

    def _pool(self, x: torch.Tensor, *, key_padding_mask: Optional[torch.Tensor]) -> torch.Tensor:
        """Reduce (B,T,D) to (B,D) according to ``self.pooling``."""
        if self.pooling == "cls":
            return x[:, 0, :]
        if key_padding_mask is None:
            return x.mean(dim=1)

        # Masked mean: weight 1 for real tokens, 0 for padding.
        weights = (~key_padding_mask).to(dtype=x.dtype).unsqueeze(-1)  # (B,T,1)
        counts = weights.sum(dim=1).clamp_min(1.0)  # (B,1); avoids 0/0 on all-pad rows
        return (x * weights).sum(dim=1) / counts

    def forward(
        self,
        tokens: torch.Tensor,
        *,
        key_padding_mask: Optional[torch.Tensor] = None,
        token_pos: Optional[torch.Tensor] = None,  # (B,T) for relpos bias (optional)
        return_attn: bool = False,
        attn_average_heads: bool = True,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]:
        """
        Encode a token sequence into one vector per sample.

        Returns ``out`` of shape (B, d_out), or ``(out, attn_maps)`` when
        ``return_attn`` is set, with one attention tensor per block.
        """
        h = self.input_proj(tokens)

        if self.use_positional_encoding:
            assert self.pos_emb is not None
            seq_len = h.shape[1]
            limit = self.pos_emb.shape[1]
            if seq_len > limit:
                raise ValueError(f"Sequence length T={seq_len} exceeds max_tokens={limit}.")
            h = h + self.pos_emb[:, :seq_len, :]

        h = self.dropout(h)

        want_attn = bool(return_attn)
        attn_maps: List[torch.Tensor] = []
        for block in self.blocks:
            result = block(
                h,
                key_padding_mask=key_padding_mask,
                token_pos=token_pos,
                return_attn=want_attn,
                attn_average_heads=attn_average_heads,
            )
            if want_attn:
                h, weights = result
                attn_maps.append(weights)
            else:
                h = result

        out = self.out_proj(self._pool(h, key_padding_mask=key_padding_mask))
        return (out, attn_maps) if return_attn else out
|
|
249
|
+
|