univi-0.3.4-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
univi/models/mlp.py ADDED
@@ -0,0 +1,36 @@
+ # univi/models/mlp.py
+
+ from __future__ import annotations
+ from typing import List, Optional
+ from torch import nn
+
+
+ def build_mlp(
+     in_dim: int,
+     hidden_dims: List[int],
+     out_dim: int,
+     activation: Optional[nn.Module] = None,
+     dropout: float = 0.0,
+     batchnorm: bool = True,
+ ) -> nn.Sequential:
+     """
+     Generic MLP builder: [Linear -> BN -> Act -> Dropout]* + final Linear.
+     (Python gotcha: don't use nn.ReLU() as a default arg; it becomes a shared instance.)
+     """
+     if activation is None:
+         activation = nn.ReLU()
+
+     layers = []
+     last_dim = in_dim
+     for h in hidden_dims:
+         layers.append(nn.Linear(last_dim, h))
+         if batchnorm:
+             layers.append(nn.BatchNorm1d(h))
+         layers.append(activation.__class__() if isinstance(activation, nn.Module) else nn.ReLU())
+         if dropout and dropout > 0:
+             layers.append(nn.Dropout(float(dropout)))
+         last_dim = h
+
+     layers.append(nn.Linear(last_dim, out_dim))
+     return nn.Sequential(*layers)
+
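A minimal usage sketch for build_mlp (the sizes below are illustrative, not taken from the package):

# Sketch: a small encoder head built with build_mlp; dimensions are made up for illustration.
import torch
from univi.models.mlp import build_mlp

mlp = build_mlp(in_dim=2000, hidden_dims=[512, 128], out_dim=32, dropout=0.1, batchnorm=True)
x = torch.randn(8, 2000)   # (batch, features)
z = mlp(x)                 # (8, 32)

Because the builder instantiates a fresh activation per hidden layer (activation.__class__()), passing a configured module such as nn.LeakyReLU(0.2) keeps the layers independent but drops its constructor arguments.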
univi/models/tokenizers.py ADDED
@@ -0,0 +1,376 @@
+ # univi/models/tokenizers.py
+ from __future__ import annotations
+
+ from typing import Optional, Tuple, Sequence, Literal, Dict, Any
+
+ import torch
+ from torch import nn
+
+ from ..config import TokenizerConfig
+
+
+ class Tokenizer(nn.Module):
+     """
+     Base tokenizer interface.
+
+     Backwards-compatible:
+         forward(x) -> (tokens, key_padding_mask)
+
+     Extras:
+         - self.last_meta is updated on each forward()
+         - forward_with_meta(x) -> (tokens, key_padding_mask, meta)
+
+     Conventions
+     -----------
+     - tokens: (B, T, D_in)
+     - key_padding_mask: Optional[(B, T)] where True means "PAD / ignore"
+     - meta: dict (optional), e.g. {"token_pos": (B, T) basepair positions}
+     """
+     def __init__(self):
+         super().__init__()
+         self.last_meta: Dict[str, Any] = {}
+
+     @property
+     def d_in(self) -> int:
+         raise NotImplementedError
+
+     def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+         raise NotImplementedError
+
+     def forward_with_meta(self, x: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor], Dict[str, Any]]:
+         tokens, mask = self.forward(x)
+         return tokens, mask, dict(self.last_meta)
+
+
+ class TopKScalarTokenizer(Tokenizer):
+     """(B,F) -> (B,K,1) using top-k by absolute value per cell."""
+     def __init__(self, n_tokens: int, add_cls_token: bool = False):
+         super().__init__()
+         self.n_tokens = int(n_tokens)
+         self.add_cls_token = bool(add_cls_token)
+
+     @property
+     def d_in(self) -> int:
+         return 1
+
+     def forward(self, x: torch.Tensor):
+         B, F = x.shape
+         K = min(self.n_tokens, F)
+
+         _, idx = torch.topk(x.abs(), k=K, dim=1, largest=True, sorted=True)
+         vals = torch.gather(x, 1, idx)  # (B,K)
+         tokens = vals.unsqueeze(-1)  # (B,K,1)
+
+         key_padding_mask = None
+         self.last_meta = {"feature_idx": idx}
+
+         if self.add_cls_token:
+             cls = torch.zeros((B, 1, 1), device=x.device, dtype=x.dtype)
+             tokens = torch.cat([cls, tokens], dim=1)
+
+         return tokens, key_padding_mask
+
+
+ class TopKChannelsTokenizer(Tokenizer):
+     """
+     (B,F) -> (B,K,C) multi-dim tokens, where channels can include:
+     - value: raw x_i
+     - rank: rank within selected K (0..1)
+     - dropout: 1 if x_i == 0 else 0
+     """
+     def __init__(
+         self,
+         n_tokens: int,
+         channels: Sequence[Literal["value", "rank", "dropout"]] = ("value", "rank", "dropout"),
+         add_cls_token: bool = False,
+     ):
+         super().__init__()
+         self.n_tokens = int(n_tokens)
+         self.channels = tuple(channels)
+         self.add_cls_token = bool(add_cls_token)
+
+         if len(self.channels) == 0:
+             raise ValueError("TopKChannelsTokenizer requires at least one channel.")
+         for c in self.channels:
+             if c not in ("value", "rank", "dropout"):
+                 raise ValueError(f"Unknown channel {c!r}. Allowed: value, rank, dropout")
+
+     @property
+     def d_in(self) -> int:
+         return len(self.channels)
+
+     def forward(self, x: torch.Tensor):
+         B, F = x.shape
+         K = min(self.n_tokens, F)
+
+         _, idx = torch.topk(x.abs(), k=K, dim=1, largest=True, sorted=True)
+         vals = torch.gather(x, 1, idx)  # (B,K)
+
+         chans = []
+         for c in self.channels:
+             if c == "value":
+                 chans.append(vals)
+             elif c == "dropout":
+                 chans.append((vals == 0).to(vals.dtype))
+             elif c == "rank":
+                 r = torch.arange(K, device=x.device, dtype=vals.dtype).view(1, K).expand(B, K)
+                 chans.append(r / max(K - 1, 1))
+             else:
+                 raise RuntimeError("unreachable")
+
+         tokens = torch.stack(chans, dim=-1)  # (B,K,C)
+         key_padding_mask = None
+         self.last_meta = {"feature_idx": idx}
+
+         if self.add_cls_token:
+             cls = torch.zeros((B, 1, tokens.size(-1)), device=x.device, dtype=x.dtype)
+             tokens = torch.cat([cls, tokens], dim=1)
+
+         return tokens, key_padding_mask
+
+
+ class PatchTokenizer(Tokenizer):
+     """
+     Split features into patches:
+
+         (B,F) -> (B,T,patch_size) where T = ceil(F/patch_size)
+
+     Optionally project:
+         patch_vec (patch_size) -> patch_proj_dim
+     """
+     def __init__(
+         self,
+         patch_size: int,
+         add_cls_token: bool = False,
+         patch_proj_dim: Optional[int] = None,
+     ):
+         super().__init__()
+         self.patch_size = int(patch_size)
+         self.add_cls_token = bool(add_cls_token)
+         self.patch_proj_dim = int(patch_proj_dim) if patch_proj_dim is not None else None
+
+         if self.patch_size <= 0:
+             raise ValueError("patch_size must be > 0")
+
+         if self.patch_proj_dim is not None:
+             self.proj = nn.Sequential(
+                 nn.LayerNorm(self.patch_size),
+                 nn.Linear(self.patch_size, self.patch_proj_dim),
+                 nn.GELU(),
+                 nn.Linear(self.patch_proj_dim, self.patch_proj_dim),
+             )
+         else:
+             self.proj = None
+
+     @property
+     def d_in(self) -> int:
+         return self.patch_proj_dim if self.patch_proj_dim is not None else self.patch_size
+
+     def forward(self, x: torch.Tensor):
+         B, F = x.shape
+         P = self.patch_size
+         T = (F + P - 1) // P
+         pad = T * P - F
+
+         if pad > 0:
+             x_pad = torch.cat([x, torch.zeros((B, pad), device=x.device, dtype=x.dtype)], dim=1)
+         else:
+             x_pad = x
+
+         patches = x_pad.view(B, T, P)  # (B,T,P)
+         key_padding_mask = None
+
+         if self.proj is not None:
+             patches = self.proj(patches)  # (B,T,patch_proj_dim)
+
+         if self.add_cls_token:
+             cls = torch.zeros((B, 1, patches.size(-1)), device=x.device, dtype=x.dtype)
+             patches = torch.cat([cls, patches], dim=1)
+
+         self.last_meta = {}
+         return patches, key_padding_mask
+
+
+ class TopKEmbeddedTokenizer(Tokenizer):
+     """
+     Top-k tokenizer with explicit feature identity embeddings:
+
+         token = Emb(feature_id) + MLP(channels(value/rank/dropout))
+
+     Optional ATAC coordinate embeddings:
+         token += Emb(chrom_id) + MLP(midpoint_bp / coord_scale)
+
+     Meta
+     ----
+     self.last_meta will include:
+     - "feature_idx": (B,K) long
+     - "token_pos": (B,K) float basepairs (if use_coords=True)
+     """
+     def __init__(
+         self,
+         *,
+         n_tokens: int,
+         n_features: int,
+         d_model: int,
+         channels: Sequence[Literal["value", "rank", "dropout"]] = ("value", "rank", "dropout"),
+         add_cls_token: bool = False,
+         value_mlp_hidden: int = 256,
+         # coordinate extras
+         use_coords: bool = False,
+         chrom_vocab_size: int = 0,
+         feature_info: Optional[Dict[str, Any]] = None,
+         coord_scale: float = 1e6,
+     ):
+         super().__init__()
+         self.n_tokens = int(n_tokens)
+         self.n_features = int(n_features)
+         self._d_model = int(d_model)
+         self.channels = tuple(channels)
+         self.add_cls_token = bool(add_cls_token)
+         self.use_coords = bool(use_coords)
+         self.chrom_vocab_size = int(chrom_vocab_size)
+         self.coord_scale = float(coord_scale)
+
+         if len(self.channels) == 0:
+             raise ValueError("TopKEmbeddedTokenizer requires at least one channel.")
+         for c in self.channels:
+             if c not in ("value", "rank", "dropout"):
+                 raise ValueError(f"Unknown channel {c!r}. Allowed: value, rank, dropout")
+
+         self.id_embed = nn.Embedding(self.n_features, self._d_model)
+
+         c_in = len(self.channels)
+         self.val_proj = nn.Sequential(
+             nn.LayerNorm(c_in),
+             nn.Linear(c_in, int(value_mlp_hidden)),
+             nn.GELU(),
+             nn.Linear(int(value_mlp_hidden), self._d_model),
+         )
+
+         # Feature metadata buffers for coords (None placeholders, filled below when use_coords=True)
+         self.register_buffer("feature_chrom", None, persistent=False)
+         self.register_buffer("feature_start", None, persistent=False)
+         self.register_buffer("feature_end", None, persistent=False)
+
+         if self.use_coords:
+             if self.chrom_vocab_size <= 0:
+                 raise ValueError("chrom_vocab_size must be > 0 when use_coords=True.")
+             if feature_info is None:
+                 raise ValueError("feature_info must be provided when use_coords=True (keys: chrom,start,end).")
+             for k in ("chrom", "start", "end"):
+                 if k not in feature_info:
+                     raise ValueError(f"feature_info missing key {k!r} (required for coords).")
+
+             chrom = torch.as_tensor(feature_info["chrom"], dtype=torch.long)
+             start = torch.as_tensor(feature_info["start"], dtype=torch.float32)
+             end = torch.as_tensor(feature_info["end"], dtype=torch.float32)
+
+             if chrom.numel() != self.n_features or start.numel() != self.n_features or end.numel() != self.n_features:
+                 raise ValueError(
+                     f"feature_info arrays must have length n_features={self.n_features}; "
+                     f"got chrom={chrom.numel()}, start={start.numel()}, end={end.numel()}."
+                 )
+
+             # register buffers so they follow .to(device)
+             self.register_buffer("feature_chrom", chrom, persistent=False)
+             self.register_buffer("feature_start", start, persistent=False)
+             self.register_buffer("feature_end", end, persistent=False)
+
+             self.chrom_embed = nn.Embedding(self.chrom_vocab_size, self._d_model)
+             self.coord_mlp = nn.Sequential(
+                 nn.LayerNorm(1),
+                 nn.Linear(1, int(value_mlp_hidden)),
+                 nn.GELU(),
+                 nn.Linear(int(value_mlp_hidden), self._d_model),
+             )
+
+     @property
+     def d_in(self) -> int:
+         return self._d_model
+
+     def forward(self, x: torch.Tensor):
+         B, F = x.shape
+         if F != self.n_features:
+             raise ValueError(f"Expected F={self.n_features}, got {F}. Did you set TokenizerConfig.n_features correctly?")
+
+         K = min(self.n_tokens, F)
+
+         _, idx = torch.topk(x.abs(), k=K, dim=1, largest=True, sorted=True)  # (B,K)
+         vals = torch.gather(x, 1, idx)  # (B,K)
+
+         # channels -> (B,K,C)
+         chans = []
+         for c in self.channels:
+             if c == "value":
+                 chans.append(vals)
+             elif c == "dropout":
+                 chans.append((vals == 0).to(vals.dtype))
+             elif c == "rank":
+                 r = torch.arange(K, device=x.device, dtype=vals.dtype).view(1, K).expand(B, K)
+                 chans.append(r / max(K - 1, 1))
+             else:
+                 raise RuntimeError("unreachable")
+
+         ch = torch.stack(chans, dim=-1)  # (B,K,C)
+
+         id_emb = self.id_embed(idx)  # (B,K,D)
+         val_emb = self.val_proj(ch)  # (B,K,D)
+         tokens = id_emb + val_emb
+
+         meta: Dict[str, Any] = {"feature_idx": idx}
+
+         if self.use_coords:
+             # buffers exist because we register_buffer above
+             chrom = self.feature_chrom[idx]  # (B,K)
+             mid = 0.5 * (self.feature_start[idx] + self.feature_end[idx])  # (B,K)
+             mid_scaled = (mid / self.coord_scale).unsqueeze(-1)  # (B,K,1)
+
+             tokens = tokens + self.chrom_embed(chrom) + self.coord_mlp(mid_scaled)
+             meta["token_pos"] = mid  # basepairs
+
+         if self.add_cls_token:
+             cls = torch.zeros((B, 1, tokens.size(-1)), device=x.device, dtype=x.dtype)
+             tokens = torch.cat([cls, tokens], dim=1)
+             # keep meta aligned if present
+             if "token_pos" in meta:
+                 cls_pos = torch.zeros((B, 1), device=x.device, dtype=meta["token_pos"].dtype)
+                 meta["token_pos"] = torch.cat([cls_pos, meta["token_pos"]], dim=1)
+
+         self.last_meta = meta
+         return tokens, None
+
+
+ def build_tokenizer(cfg: TokenizerConfig) -> Tokenizer:
+     mode = (cfg.mode or "").lower().strip()
+
+     if mode == "topk_scalar":
+         return TopKScalarTokenizer(n_tokens=cfg.n_tokens, add_cls_token=cfg.add_cls_token)
+
+     if mode == "topk_channels":
+         return TopKChannelsTokenizer(n_tokens=cfg.n_tokens, channels=cfg.channels, add_cls_token=cfg.add_cls_token)
+
+     if mode == "patch":
+         return PatchTokenizer(
+             patch_size=cfg.patch_size,
+             add_cls_token=cfg.add_cls_token,
+             patch_proj_dim=cfg.patch_proj_dim,
+         )
+
+     if mode == "topk_embed":
+         if cfg.n_features is None or cfg.d_model is None:
+             raise ValueError("TokenizerConfig.mode='topk_embed' requires n_features and d_model to be set.")
+         return TopKEmbeddedTokenizer(
+             n_tokens=cfg.n_tokens,
+             n_features=int(cfg.n_features),
+             d_model=int(cfg.d_model),
+             channels=cfg.channels,
+             add_cls_token=cfg.add_cls_token,
+             value_mlp_hidden=int(cfg.value_mlp_hidden),
+             use_coords=bool(cfg.use_coords),
+             chrom_vocab_size=int(cfg.chrom_vocab_size),
+             feature_info=cfg.feature_info,
+             coord_scale=float(cfg.coord_scale),
+         )
+
+     raise ValueError(f"Unknown tokenizer mode {cfg.mode!r}")
+
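A minimal sketch of calling one of the tokenizers directly (batch and feature counts are illustrative; building through TokenizerConfig/build_tokenizer yields the same shapes):

# Sketch: tokenize a (B, F) matrix with TopKChannelsTokenizer; sizes are illustrative.
import torch
from univi.models.tokenizers import TopKChannelsTokenizer

tok = TopKChannelsTokenizer(n_tokens=256, channels=("value", "rank", "dropout"))
x = torch.randn(4, 2000)                                   # (B, F) per-cell feature vector
tokens, key_padding_mask, meta = tok.forward_with_meta(x)
# tokens: (4, 256, 3) -- K tokens, one channel per entry in `channels`
# key_padding_mask: None (all K tokens are real, nothing is padded)
# meta["feature_idx"]: (4, 256) indices of the selected features per cell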
univi/models/transformer.py ADDED
@@ -0,0 +1,249 @@
+ # univi/models/transformer.py
+ from __future__ import annotations
+
+ from dataclasses import dataclass
+ from typing import Optional, Literal, List, Tuple, Union
+
+ import torch
+ from torch import nn
+ import torch.nn.functional as F
+
+
+ @dataclass
+ class TransformerConfig:
+     d_model: int
+     num_heads: int
+     num_layers: int
+     dim_feedforward: int = 4096
+     dropout: float = 0.1
+     attn_dropout: float = 0.1
+     activation: Literal["relu", "gelu"] = "gelu"
+     pooling: Literal["cls", "mean"] = "mean"
+     max_tokens: Optional[int] = None
+
+     # Optional: binned relative-position attention bias (e.g., genomic distance)
+     use_relpos_bias: bool = False
+     relpos_num_bins: int = 32
+     relpos_max_dist: float = 1e6  # basepairs
+
+
+ def _act(name: str):
+     name = str(name).lower().strip()
+     if name == "relu":
+         return F.relu
+     if name == "gelu":
+         return F.gelu
+     raise ValueError(f"Unknown activation: {name!r}")
+
+
+ class GenomicRelPosBias(nn.Module):
+     """
+     Simple distance-binned relative attention bias.
+
+     Given token positions pos (B,T) in basepairs, returns an additive bias
+     (B, H, T, T). Intended for ATAC peak midpoints.
+
+     Notes
+     -----
+     - Uses log1p compression to allocate more bins to shorter distances.
+     - Bias table is learned: (H, num_bins).
+     """
+     def __init__(self, num_heads: int, num_bins: int = 32, max_dist: float = 1e6):
+         super().__init__()
+         self.num_heads = int(num_heads)
+         self.num_bins = int(num_bins)
+         self.max_dist = float(max_dist)
+         self.bias = nn.Parameter(torch.zeros(self.num_heads, self.num_bins))
+
+     def _bin(self, dist: torch.Tensor) -> torch.Tensor:
+         # dist: (B,T,T) >= 0
+         d = dist.clamp(min=0.0, max=self.max_dist)
+         d = torch.log1p(d)
+         dmax = torch.log1p(torch.tensor(self.max_dist, device=d.device, dtype=d.dtype))
+         b = (d / dmax) * (self.num_bins - 1)
+         return b.to(torch.long)
+
+     def forward(self, pos: torch.Tensor) -> torch.Tensor:
+         # pos: (B,T)
+         dist = (pos[:, :, None] - pos[:, None, :]).abs()  # (B,T,T)
+         bins = self._bin(dist)  # (B,T,T)
+         # bias[:, bins] -> (H,B,T,T) then permute -> (B,H,T,T)
+         out = self.bias[:, bins]
+         return out.permute(1, 0, 2, 3).contiguous()
+
+
+ class TransformerBlock(nn.Module):
+     """
+     Single post-norm block (residual, then LayerNorm):
+         x -> MHA -> residual -> LN
+           -> FFN -> residual -> LN
+
+     Supports optional additive attention bias (e.g., relative position).
+     """
+     def __init__(self, cfg: TransformerConfig):
+         super().__init__()
+         self.cfg = cfg
+         d_model = int(cfg.d_model)
+         self.num_heads = int(cfg.num_heads)
+
+         self.attn = nn.MultiheadAttention(
+             embed_dim=d_model,
+             num_heads=self.num_heads,
+             dropout=float(cfg.attn_dropout),
+             batch_first=True,
+         )
+         self.attn_drop = nn.Dropout(float(cfg.dropout))
+         self.ln1 = nn.LayerNorm(d_model)
+
+         self.ff = nn.Sequential(
+             nn.Linear(d_model, int(cfg.dim_feedforward)),
+             nn.GELU() if str(cfg.activation).lower().strip() == "gelu" else nn.ReLU(),
+             nn.Dropout(float(cfg.dropout)),
+             nn.Linear(int(cfg.dim_feedforward), d_model),
+         )
+         self.ff_drop = nn.Dropout(float(cfg.dropout))
+         self.ln2 = nn.LayerNorm(d_model)
+
+         self.relpos: Optional[GenomicRelPosBias] = None
+         if bool(getattr(cfg, "use_relpos_bias", False)):
+             self.relpos = GenomicRelPosBias(
+                 num_heads=self.num_heads,
+                 num_bins=int(getattr(cfg, "relpos_num_bins", 32)),
+                 max_dist=float(getattr(cfg, "relpos_max_dist", 1e6)),
+             )
+
+     def forward(
+         self,
+         x: torch.Tensor,
+         *,
+         key_padding_mask: Optional[torch.Tensor] = None,
+         token_pos: Optional[torch.Tensor] = None,  # (B,T) basepairs or other coordinates
+         return_attn: bool = False,
+         attn_average_heads: bool = True,
+     ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+         need_weights = bool(return_attn)
+
+         attn_mask = None
+         if self.relpos is not None and token_pos is not None:
+             # (B,H,T,T) -> (B*H,T,T) for nn.MultiheadAttention
+             bias = self.relpos(token_pos).to(dtype=x.dtype)
+             B, H, T, _ = bias.shape
+             attn_mask = bias.view(B * H, T, T)
+
+         attn_out, attn_w = self.attn(
+             x, x, x,
+             key_padding_mask=key_padding_mask,  # (B, T) True = PAD
+             attn_mask=attn_mask,  # None or (B*H,T,T)
+             need_weights=need_weights,
+             average_attn_weights=bool(attn_average_heads),
+         )
+
+         x = self.ln1(x + self.attn_drop(attn_out))
+         ff_out = self.ff(x)
+         x = self.ln2(x + self.ff_drop(ff_out))
+
+         if return_attn:
+             if attn_w is None:
+                 raise RuntimeError("Expected attn_w when return_attn=True, got None.")
+             return x, attn_w
+         return x
+
+
+ class TransformerEncoder(nn.Module):
+     """
+     Generic encoder:
+         tokens (B,T,D_in) -> proj -> blocks -> pool -> out_proj -> (B,d_out)
+
+     Optional:
+     - learned absolute positional embeddings (use_positional_encoding=True)
+     - relative attention bias via token_pos (if cfg.use_relpos_bias=True)
+     """
+     def __init__(
+         self,
+         *,
+         cfg: TransformerConfig,
+         d_in: int,
+         d_out: int,
+         use_positional_encoding: bool = True,
+     ):
+         super().__init__()
+         self.cfg = cfg
+         self.use_positional_encoding = bool(use_positional_encoding)
+
+         d_model = int(cfg.d_model)
+         self.input_proj = nn.Identity() if int(d_in) == d_model else nn.Linear(int(d_in), d_model, bias=True)
+         self.blocks = nn.ModuleList([TransformerBlock(cfg) for _ in range(int(cfg.num_layers))])
+         self.dropout = nn.Dropout(float(cfg.dropout))
+         self.out_proj = nn.Linear(d_model, int(d_out), bias=True)
+
+         self.pooling = str(cfg.pooling).lower().strip()
+         if self.pooling not in ("cls", "mean"):
+             raise ValueError(f"Unknown pooling={cfg.pooling!r}")
+
+         # learned positional embeddings (optional)
+         self.pos_emb: Optional[nn.Parameter] = None
+         if self.use_positional_encoding:
+             if cfg.max_tokens is None:
+                 raise ValueError("use_positional_encoding=True requires cfg.max_tokens to be set.")
+             max_tokens = int(cfg.max_tokens)
+             self.pos_emb = nn.Parameter(torch.zeros(1, max_tokens, d_model))
+             nn.init.normal_(self.pos_emb, mean=0.0, std=0.02)
+
+     def _pool(self, x: torch.Tensor, *, key_padding_mask: Optional[torch.Tensor]) -> torch.Tensor:
+         if self.pooling == "cls":
+             return x[:, 0, :]
+
+         if key_padding_mask is None:
+             return x.mean(dim=1)
+
+         keep = (~key_padding_mask).to(dtype=x.dtype)  # (B, T)
+         denom = keep.sum(dim=1, keepdim=True).clamp_min(1.0)
+         return (x * keep.unsqueeze(-1)).sum(dim=1) / denom
+
+     def forward(
+         self,
+         tokens: torch.Tensor,
+         *,
+         key_padding_mask: Optional[torch.Tensor] = None,
+         token_pos: Optional[torch.Tensor] = None,  # (B,T) for relpos bias (optional)
+         return_attn: bool = False,
+         attn_average_heads: bool = True,
+     ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]:
+         x = self.input_proj(tokens)
+
+         if self.use_positional_encoding:
+             assert self.pos_emb is not None
+             T = x.shape[1]
+             if T > self.pos_emb.shape[1]:
+                 raise ValueError(f"Sequence length T={T} exceeds max_tokens={self.pos_emb.shape[1]}.")
+             x = x + self.pos_emb[:, :T, :]
+
+         x = self.dropout(x)
+
+         attn_all: List[torch.Tensor] = []
+         for blk in self.blocks:
+             if return_attn:
+                 x, aw = blk(
+                     x,
+                     key_padding_mask=key_padding_mask,
+                     token_pos=token_pos,
+                     return_attn=True,
+                     attn_average_heads=attn_average_heads,
+                 )
+                 attn_all.append(aw)
+             else:
+                 x = blk(
+                     x,
+                     key_padding_mask=key_padding_mask,
+                     token_pos=token_pos,
+                     return_attn=False,
+                     attn_average_heads=attn_average_heads,
+                 )
+
+         pooled = self._pool(x, key_padding_mask=key_padding_mask)
+         out = self.out_proj(pooled)
+
+         if return_attn:
+             return out, attn_all
+         return out
+
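A minimal sketch wiring PatchTokenizer into TransformerEncoder; all hyperparameters below are illustrative:

# Sketch: PatchTokenizer -> TransformerEncoder; hyperparameters are made up for illustration.
import torch
from univi.models.tokenizers import PatchTokenizer
from univi.models.transformer import TransformerConfig, TransformerEncoder

tok = PatchTokenizer(patch_size=64, patch_proj_dim=128)
cfg = TransformerConfig(d_model=128, num_heads=4, num_layers=2,
                        dim_feedforward=256, max_tokens=512, pooling="mean")
enc = TransformerEncoder(cfg=cfg, d_in=tok.d_in, d_out=64, use_positional_encoding=True)

x = torch.randn(4, 2000)                              # (B, F)
tokens, key_padding_mask = tok(x)                     # (4, 32, 128); mask is None for patch tokens
z = enc(tokens, key_padding_mask=key_padding_mask)    # (4, 64) pooled embedding per cell

With cfg.use_relpos_bias=True and a coordinate-aware tokenizer, the basepair positions in meta["token_pos"] can be passed as token_pos= so each block adds its distance-binned attention bias.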