broccoli-ml 9.5.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,779 @@
1
+ import math
2
+ from typing import Optional, Tuple
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ from torch.utils.checkpoint import checkpoint
8
+
9
+ from einops import rearrange
10
+
11
+ from .rope import RotaryEmbedding, apply_rotary_emb
12
+
13
+ try:
14
+ from flash_attn import flash_attn_func
15
+
16
+ print("Using flash-attn.")
17
+ FLASH_ATTN = True
18
+ except ImportError:
19
+ pass
20
+ FLASH_ATTN = False
21
+
22
+
23
+ def drop_path(
24
+ x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True
25
+ ):
26
+ """
27
+ From https://github.com/huggingface/pytorch-image-models/blob/main/timm/layers/drop.py
28
+ Copyright 2019 Ross Wightman
29
+ See documentation and licence there.
30
+ """
31
+ if drop_prob == 0.0 or not training:
32
+ return x
33
+ keep_prob = 1 - drop_prob
34
+ shape = (x.shape[0],) + (1,) * (
35
+ x.ndim - 1
36
+ ) # work with diff dim tensors, not just 2D ConvNets
37
+ random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
38
+ if keep_prob > 0.0 and scale_by_keep:
39
+ random_tensor.div_(keep_prob)
40
+ return x * random_tensor
41
+
42
+
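# --- Editor's note: illustrative sketch, not part of the packaged file ---
# drop_path implements stochastic depth at the sample level: each example in the
# batch is either kept or zeroed entirely, and with scale_by_keep=True the kept
# samples are rescaled by 1/keep_prob so the expected activation is unchanged.
x_demo = torch.ones(4, 3, 8)                              # (batch, tokens, features)
out_demo = drop_path(x_demo, drop_prob=0.5, training=True)
# Each sample in out_demo is now either all zeros or all 2.0 (= 1 / keep_prob).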
43
+ class DropPath(nn.Module):
44
+ """
45
+ From https://github.com/huggingface/pytorch-image-models/blob/main/timm/layers/drop.py
46
+ Copyright 2019 Ross Wightman
47
+ See documentation and licence there.
48
+ """
49
+
50
+ def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True):
51
+ super(DropPath, self).__init__()
52
+ self.drop_prob = drop_prob
53
+ self.scale_by_keep = scale_by_keep
54
+
55
+ def forward(self, x):
56
+ return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)
57
+
58
+ def extra_repr(self):
59
+ return f"drop_prob={round(self.drop_prob, 3):0.3f}"
60
+
61
+
62
+ class MHAttention(nn.Module):
63
+ """
64
+ Multi-head self-attention using einops and optionally a custom linear layer.
65
+
66
+ Forward method assumes q, k and v have the same embedding size and k and v
67
+ are the same shape.
68
+
69
+ Assumes bias=False and batch_first=True, as God intended.
70
+ """
71
+
72
+ def __init__(
73
+ self,
74
+ embed_dim,
75
+ n_heads,
76
+ dropout=0.0,
77
+ causal=False,
78
+ seq_len=None,
79
+ linear_module: nn.Module = nn.Linear,
80
+ utility_tokens=0,
81
+ rotary_embedding=None,
82
+ source_size=None,
83
+ scaling="d",
84
+ ):
85
+ """
86
+ Args:
87
+ scaling: how the attention logits are scaled. Either "sqrtd", which
88
+ mimics the original Attention Is All You Need approach of dividing
89
+ by the square root of the per-head dimension, or "d", which scales
90
+ proportionally to 1 / head_dim per "Tensor Programs V...". Default "d".
91
+ """
92
+ super().__init__()
93
+
94
+ if rotary_embedding is not None:
95
+ assert source_size is not None
96
+ if causal:
97
+ assert seq_len is not None
98
+
99
+ self.embed_dim = embed_dim
100
+ self.n_heads = n_heads
101
+ assert embed_dim % n_heads == 0
102
+ self.scaling = scaling
103
+
104
+ self.head_dim = self.embed_dim // self.n_heads
105
+
106
+ if self.scaling == "sqrtd":
107
+ self.scaling_factor = 1 / math.sqrt(self.head_dim)
108
+ elif self.scaling == "d":
109
+ # 8/head_dim for backwards compatibility (matches 1/sqrt(d) at head_dim=64),
110
+ # per https://github.com/microsoft/mup
111
+ self.scaling_factor = 8 / self.head_dim
112
+ else:
113
+ raise ValueError('`scaling` argument to MHAttention must be "d" or "sqrtd"')
114
+
115
+ self.q_proj = linear_module(self.embed_dim, self.embed_dim, bias=False)
116
+ self.k_proj = linear_module(self.embed_dim, self.embed_dim, bias=False)
117
+ self.v_proj = linear_module(self.embed_dim, self.embed_dim, bias=False)
118
+
119
+ self.out_proj = linear_module(self.embed_dim, self.embed_dim, bias=False)
120
+
121
+ self.causal = causal
122
+ self.seq_len = seq_len
123
+ self.dropout = nn.Dropout(dropout)
124
+ if self.causal:
125
+ self.register_buffer(
126
+ "mask",
127
+ (torch.triu(torch.ones(seq_len, seq_len), diagonal=1) == 1)
128
+ .unsqueeze(0)
129
+ .unsqueeze(0),
130
+ )
131
+ self.rotary_embedding = rotary_embedding
132
+ self.source_size = source_size
133
+ self.utility_tokens = utility_tokens
134
+
135
+ self.reset_parameters()
136
+
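# --- Editor's note: worked example, assuming embed_dim=512 and n_heads=8 ---
# head_dim = 512 // 8 = 64
#   scaling="sqrtd": 1 / sqrt(64) = 0.125   (Attention Is All You Need)
#   scaling="d":     8 / 64       = 0.125   (muP; identical at head_dim=64,
#                                            diverging for other head sizes)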
137
+ @property
138
+ def _kv_distance(self) -> float:
139
+ """
140
+ Calculates the cosine distance between the weight tensors of `self.k_proj`
141
+ and `self.v_proj`.
142
+
143
+ The cosine distance is defined as 1 - cosine_similarity (i.e. a value
144
+ closer to 0 indicates higher similarity).
145
+ """
146
+
147
+ similarity = F.cosine_similarity(
148
+ self.k_proj.weight.detach().flatten(),
149
+ self.v_proj.weight.detach().flatten(),
150
+ dim=0,
151
+ eps=1e-8,
152
+ ).item()
153
+
154
+ return 1 - similarity
155
+
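# --- Editor's note: illustrative interpretation ---
# _kv_distance is a training-time diagnostic: at initialisation the flattened
# K and V projection weights are independent high-dimensional random vectors,
# so their cosine similarity is near 0 and the reported distance is near 1.
# Values drifting towards 0 would indicate the two projections converging.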
156
+ def add_axial_rope(
157
+ self, q: torch.Tensor, k: torch.Tensor
158
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
159
+ """
160
+ Apply Axial RoPE to all tokens except utility tokens
161
+ """
162
+
163
+ if len(self.source_size) == 1:
164
+ spatial_dimension_names = "D1"
165
+ spatial_dimension_values = {"D1": self.source_size[0]}
166
+ elif len(self.source_size) == 2:
167
+ spatial_dimension_names = "D1 D2"
168
+ spatial_dimension_values = {
169
+ "D1": self.source_size[0],
170
+ "D2": self.source_size[1],
171
+ }
172
+ elif len(self.source_size) == 3:
173
+ spatial_dimension_names = "D1 D2 D3"
174
+ spatial_dimension_values = {
175
+ "D1": self.source_size[0],
176
+ "D2": self.source_size[1],
177
+ "D3": self.source_size[2],
178
+ }
179
+ else:
180
+ raise NotImplementedError(
181
+ "`source_size` must be a tuple of 1, 2 or 3 integers"
182
+ )
183
+
184
+ q_util, q_img = q[:, : self.utility_tokens, :], q[:, self.utility_tokens :, :]
185
+ k_util, k_img = k[:, : self.utility_tokens, :], k[:, self.utility_tokens :, :]
186
+
187
+ q_img = rearrange(
188
+ q_img,
189
+ f"b ({spatial_dimension_names}) d -> b {spatial_dimension_names} d",
190
+ **spatial_dimension_values,
191
+ )
192
+ k_img = rearrange(
193
+ k_img,
194
+ f"b ({spatial_dimension_names}) d -> b {spatial_dimension_names} d",
195
+ **spatial_dimension_values,
196
+ )
197
+
198
+ freqs = self.rotary_embedding.get_axial_freqs(*self.source_size)
199
+
200
+ q_img = apply_rotary_emb(freqs, q_img)
201
+ k_img = apply_rotary_emb(freqs, k_img)
202
+
203
+ q_img = rearrange(
204
+ q_img,
205
+ f"b {spatial_dimension_names} d -> b ({spatial_dimension_names}) d",
206
+ )
207
+ k_img = rearrange(
208
+ k_img,
209
+ f"b {spatial_dimension_names} d -> b ({spatial_dimension_names}) d",
210
+ )
211
+
212
+ # Re-combine the utility tokens and the RoPE-enhanced sequence tokens
213
+ q = torch.cat([q_util, q_img], dim=1)
214
+ k = torch.cat([k_util, k_img], dim=1)
215
+
216
+ return q, k
217
+
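# --- Editor's note: illustrative shapes, assuming utility_tokens=1 and source_size=(4, 4) ---
# q and k arrive as (batch, 1 + 16, dim). The single utility token is split off,
# the 16 image tokens are reshaped to (batch, 4, 4, dim), rotated with axial RoPE
# frequencies for a 4x4 grid, flattened back to (batch, 16, dim), and finally
# re-concatenated with the untouched utility token, giving (batch, 17, dim) again.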
218
+ def project_qkv(
219
+ self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor
220
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
221
+ query_batch_size, query_tokens, query_features = q.size()
222
+ key_batch_size, key_tokens, key_features = k.size()
223
+
224
+ assert k.size() == v.size()
225
+ assert query_features == key_features
226
+ assert (
227
+ (query_batch_size == key_batch_size) # batch sizes are the same...
228
+ or query_batch_size == 1 # ... or query is broadcastable
229
+ )
230
+
231
+ if self.causal:
232
+ assert query_tokens == key_tokens
233
+ assert query_tokens == self.seq_len
234
+
235
+ q, k, v = self.q_proj(q), self.k_proj(k), self.v_proj(v)
236
+
237
+ if self.rotary_embedding is not None:
238
+ q, k = self.add_axial_rope(q, k)
239
+
240
+ return q, k, v
241
+
242
+ def forward(self, q, k, v):
243
+
244
+ q, k, v = self.project_qkv(q, k, v)
245
+
246
+ if FLASH_ATTN:
247
+ # Divide Q/K/V into heads
248
+ q = rearrange(q, "b t (h d) -> b t h d", h=self.n_heads)
249
+ k = rearrange(k, "b t (h d) -> b t h d", h=self.n_heads)
250
+ v = rearrange(v, "b t (h d) -> b t h d", h=self.n_heads)
251
+
252
+ output_with_heads = flash_attn_func(
253
+ q,
254
+ k,
255
+ v,
256
+ dropout_p=self.dropout.p if self.training else 0.0,
257
+ softmax_scale=self.scaling_factor,
258
+ causal=self.causal,
259
+ )
260
+
261
+ output_without_heads = rearrange(output_with_heads, "b t h d -> b t (h d)")
262
+
263
+ return self.out_proj(output_without_heads)
264
+ else:
265
+ # Divide Q/K/V into heads
266
+ q = rearrange(q, "b t (h d) -> b h t d", h=self.n_heads)
267
+ k = rearrange(k, "b t (h d) -> b h t d", h=self.n_heads)
268
+ v = rearrange(v, "b t (h d) -> b h t d", h=self.n_heads)
269
+
270
+ qk_scores = q @ k.transpose(-1, -2)
271
+
272
+ qk_scores *= self.scaling_factor
273
+
274
+ # Apply mask if causal (must come before softmax)
275
+ if self.causal:
276
+ qk_scores.masked_fill_(self.mask, float("-inf"))
277
+
278
+ qk_scores = F.softmax(qk_scores, dim=-1)
279
+
280
+ qk_scores = self.dropout(qk_scores)
281
+
282
+ output_with_heads = qk_scores @ v
283
+
284
+ output_without_heads = rearrange(output_with_heads, "b h t d -> b t (h d)")
285
+
286
+ return self.out_proj(output_without_heads)
287
+
288
+ def attention_logits(self, q, k, v):
289
+
290
+ q, k, v = self.project_qkv(q, k, v)
291
+
292
+ # Divide Q/K/V into heads
293
+ q = rearrange(q, "b t (h d) -> b h t d", h=self.n_heads)
294
+ k = rearrange(k, "b t (h d) -> b h t d", h=self.n_heads)
295
+ v = rearrange(v, "b t (h d) -> b h t d", h=self.n_heads)
296
+
297
+ qk_scores = q @ k.transpose(-1, -2)
298
+
299
+ qk_scores *= self.scaling_factor
300
+
301
+ # Apply mask if causal (must come before softmax)
302
+ if self.causal:
303
+ qk_scores.masked_fill_(self.mask, float("-inf"))
304
+
305
+ return qk_scores # (batch, head, seq_len, seq_len)
306
+
307
+ def reset_parameters(self):
308
+ # Default nn.Linear init is kaiming_uniform, which is fine
309
+ self.q_proj.reset_parameters()
310
+ self.k_proj.reset_parameters()
311
+ self.v_proj.reset_parameters()
312
+ self.out_proj.reset_parameters()
313
+
314
+
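# --- Editor's note: minimal usage sketch, not part of the packaged file ---
# Causal self-attention over 16-token sequences, assuming flash-attn is not
# installed (flash_attn_func additionally expects fp16/bf16 tensors on CUDA),
# so the pure-PyTorch attention path is taken.
attn_demo = MHAttention(embed_dim=64, n_heads=4, causal=True, seq_len=16)
x_attn_demo = torch.randn(2, 16, 64)
print(attn_demo(x_attn_demo, x_attn_demo, x_attn_demo).shape)  # torch.Size([2, 16, 64])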
315
+ class FeedforwardBlock(nn.Module):
316
+ """
317
+ Feedforward (MLP) block: up-projection, activation (optionally a GLU variant), dropout, down-projection, with optional pre-/post-LayerNorm and residual connection.
318
+ """
319
+
320
+ def __init__(
321
+ self,
322
+ input_features,
323
+ ratio,
324
+ output_features,
325
+ activation=nn.ReLU,
326
+ activation_kwargs=None,
327
+ dropout=0.0,
328
+ inner_dropout=None,
329
+ outer_dropout=None,
330
+ linear_module_up=nn.Linear,
331
+ linear_module_down=nn.Linear,
332
+ pre_norm=True,
333
+ normformer=False,
334
+ post_norm=True,
335
+ residual_path=True,
336
+ checkpoint=True,
337
+ ):
338
+ super().__init__()
339
+
340
+ self.checkpoint = checkpoint
341
+ self.residual_path = residual_path
342
+ self.post_norm = post_norm
343
+ self.xglu = activation.__name__.endswith("GLU")
344
+
345
+ if self.residual_path and (output_features < input_features):
346
+ raise ValueError(
347
+ "If the number of output features will be less than "
348
+ "the number of input features, then `residual_path` "
349
+ "should be set to False."
350
+ )
351
+
352
+ if self.post_norm:
353
+ self.layernorm = nn.LayerNorm(output_features)
354
+
355
+ if activation_kwargs is not None:
356
+ self.activation = activation(**activation_kwargs)
357
+ else:
358
+ self.activation = activation()
359
+
360
+ self.inner_dropout = nn.Dropout(
361
+ inner_dropout if inner_dropout is not None else dropout
362
+ )
363
+ self.outer_dropout = nn.Dropout(
364
+ outer_dropout if outer_dropout is not None else dropout
365
+ )
366
+
367
+ self.max_features = (
368
+ 2 * ratio * output_features if self.xglu else ratio * output_features
369
+ )
370
+
371
+ self.linear_in = linear_module_up(input_features, self.max_features)
372
+ self.linear_out = linear_module_down(ratio * output_features, output_features)
373
+
374
+ self.process = nn.Sequential(
375
+ *[
376
+ nn.LayerNorm(input_features) if pre_norm else nn.Identity(),
377
+ self.linear_in,
378
+ self.activation,
379
+ self.inner_dropout,
380
+ nn.LayerNorm(ratio * output_features) if normformer else nn.Identity(),
381
+ self.linear_out,
382
+ self.outer_dropout,
383
+ ]
384
+ )
385
+
386
+ self.recycling_enabled = False
387
+ if hasattr(self.linear_in, "row_recycling_rate") and hasattr(
388
+ self.linear_out, "column_recycling_rate"
389
+ ):
390
+ self.recycling_enabled = True
391
+ self.master_recycling_rate = self.linear_in.row_recycling_rate
392
+ self.linear_in.row_recycling_rate = 0.0
393
+ self.linear_out.column_recycling_rate = 0.0
394
+ if (
395
+ hasattr(self.linear_in, "column_recycling_rate")
396
+ and self.linear_in.column_recycling_rate > 0
397
+ ) or (
398
+ hasattr(self.linear_out, "row_recycling_rate")
399
+ and self.linear_out.row_recycling_rate > 0
400
+ ):
401
+ raise NotImplementedError(
402
+ "At the moment this layer can only support recycling linear "
403
+ "layers if the in layer resets only rows and the out layer "
404
+ "resets only columns."
405
+ )
406
+
407
+ self.reset_parameters()
408
+
409
+ def forward(self, x):
410
+
411
+ # Recycle weights if using recycling linear layers
412
+ if self.training and self.recycling_enabled:
413
+ indices = self.linear_out.get_reset_indices(1)
414
+ self.linear_in.reset_rows(indices)
415
+ self.linear_out.reset_columns(indices)
416
+
417
+ if self.checkpoint:
418
+ processed = checkpoint(self.process, x, use_reentrant=False)
419
+ else:
420
+ processed = self.process(x)
421
+
422
+ if self.residual_path and self.post_norm:
423
+ return self.layernorm(x + processed)
424
+ elif self.residual_path:
425
+ return x + processed
426
+ else:
427
+ return processed
428
+
429
+ def reset_parameters(self):
430
+ if self.post_norm:
431
+ self.layernorm.reset_parameters()
432
+
433
+ # Iterate over the sequential block to reset parameters
434
+ for module in self.process:
435
+ if hasattr(module, "reset_parameters"):
436
+ module.reset_parameters()
437
+
438
+
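# --- Editor's note: illustrative sketch of the GLU sizing logic, not part of the packaged file ---
# Because nn.GLU's class name ends in "GLU", linear_in is widened to
# 2 * ratio * output_features and the activation halves it again:
#   linear_in: 64 -> 2 * 4 * 64 = 512, nn.GLU: 512 -> 256, linear_out: 256 -> 64.
ff_demo = FeedforwardBlock(64, 4, 64, activation=nn.GLU)
print(ff_demo(torch.randn(2, 10, 64)).shape)  # torch.Size([2, 10, 64])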
439
+ class TransformerBlock(nn.Module):
440
+ """
441
+ By default applies LayerNorm before each sublayer (as in PyTorch
442
+ Transformers when norm_first=True), which is also seen in e.g.
443
+ https://github.com/karpathy/minGPT/blob/master/mingpt/model.py
444
+ and is recommended by https://arxiv.org/abs/2002.04745
445
+
446
+ """
447
+
448
+ def __init__(
449
+ self,
450
+ seq_len,
451
+ d_model,
452
+ n_heads,
453
+ relative_position_embedding=False,
454
+ source_size=None,
455
+ utility_tokens=0,
456
+ mlp_ratio=4,
457
+ activation: nn.Module = nn.ReLU,
458
+ activation_kwargs: Optional[dict] = None,
459
+ ff_linear_module_up=None,
460
+ ff_linear_module_down=None,
461
+ msa_scaling="d",
462
+ ff_dropout=0.0,
463
+ ff_inner_dropout=0.0,
464
+ ff_outer_dropout=0.0,
465
+ msa_dropout=0.0,
466
+ identity_probability=0.0,
467
+ causal=False,
468
+ linear_module=nn.Linear,
469
+ pre_norm=True,
470
+ post_norm=False,
471
+ normformer=False,
472
+ checkpoint_ff=True,
473
+ ):
474
+ """
475
+ Args:
476
+ msa_scaling: how the attention logits are scaled. Either "sqrtd", which
477
+ mimics the original Attention Is All You Need approach of dividing
478
+ by the square root of the per-head dimension, or "d", which scales
479
+ proportionally to 1 / head_dim per "Tensor Programs V...". Default "d".
480
+ """
481
+
482
+ super().__init__()
483
+
484
+ self.pre_norm = pre_norm
485
+ self.post_norm = post_norm
486
+ self.normformer = normformer
487
+
488
+ self.drop_path = DropPath(drop_prob=identity_probability, scale_by_keep=True)
489
+
490
+ self.layer_norm_1 = nn.LayerNorm(d_model)
491
+ self.layer_norm_2 = nn.LayerNorm(d_model)
492
+ self.layer_norm_3 = nn.LayerNorm(d_model)
493
+
494
+ if relative_position_embedding:
495
+ max_freq = int(max(source_size) / 2) # Suggested by Gemini!
496
+ if d_model < 16:
497
+ dim = d_model
498
+ else:
499
+ dim = 16
500
+ self.rotary_embedding = RotaryEmbedding(
501
+ dim=dim, freqs_for="pixel", max_freq=max_freq
502
+ )
503
+ else:
504
+ self.rotary_embedding = None
505
+
506
+ self.attn = MHAttention( # Handles QKV projection
507
+ d_model,
508
+ n_heads,
509
+ dropout=msa_dropout,
510
+ causal=causal,
511
+ seq_len=seq_len,
512
+ linear_module=linear_module,
513
+ rotary_embedding=self.rotary_embedding,
514
+ source_size=source_size,
515
+ utility_tokens=utility_tokens,
516
+ scaling=msa_scaling,
517
+ )
518
+
519
+ # Submodule for the feedforward process
520
+ self.ff = FeedforwardBlock(
521
+ d_model,
522
+ mlp_ratio,
523
+ d_model,
524
+ activation=activation,
525
+ activation_kwargs=activation_kwargs,
526
+ dropout=ff_dropout,
527
+ inner_dropout=ff_inner_dropout,
528
+ outer_dropout=ff_outer_dropout,
529
+ linear_module_up=(
530
+ ff_linear_module_up
531
+ if ff_linear_module_up is not None
532
+ else linear_module
533
+ ),
534
+ linear_module_down=(
535
+ ff_linear_module_down
536
+ if ff_linear_module_down is not None
537
+ else linear_module
538
+ ),
539
+ pre_norm=False, # Handled outside the block
540
+ normformer=normformer,
541
+ post_norm=False, # Handled outside the block
542
+ residual_path=False, # Handled outside the block
543
+ checkpoint=checkpoint_ff,
544
+ )
545
+
546
+ self.reset_parameters()
547
+
548
+ @property
549
+ def _kv_distance(self) -> float:
550
+ return self.attn._kv_distance
551
+
552
+ def forward(self, x):
553
+
554
+ if self.pre_norm:
555
+ x = self.layer_norm_1(x)
556
+ x = x + self.drop_path(self.attn(x, x, x))
557
+ x = self.layer_norm_2(x)
558
+ x = x + self.drop_path(self.ff(x))
559
+ if self.post_norm: # i.e. in addition! Pre and post.
560
+ x = self.layer_norm_3(x)
561
+ elif self.post_norm: # i.e. only, not prenorm, just post
562
+ x = x + self.drop_path(self.attn(x, x, x))
563
+ x = self.layer_norm_1(x)
564
+ x = x + self.drop_path(self.ff(x))
565
+ x = self.layer_norm_2(x)
566
+ else: # Not pre or post norm. Stand well back.
567
+ x = x + self.drop_path(self.attn(x, x, x))
568
+ x = x + self.drop_path(self.ff(x))
569
+
570
+ return x
571
+
572
+ def attention_logits(self, x):
573
+ """
574
+ Return the pre-softmax attention logits computed by this layer.
575
+ """
576
+ if self.pre_norm:
577
+ x = self.layer_norm_1(x)
578
+ return self.attn.attention_logits(x, x, x)
579
+ else:
580
+ return self.attn.attention_logits(x, x, x)
581
+
582
+ def reset_parameters(self):
583
+ self.layer_norm_1.reset_parameters()
584
+ self.layer_norm_2.reset_parameters()
585
+ self.layer_norm_3.reset_parameters()
586
+
587
+ self.attn.reset_parameters()
588
+ self.ff.reset_parameters()
589
+
590
+
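# --- Editor's note: minimal usage sketch, not part of the packaged file ---
# A single pre-norm block over 16-token sequences (flash-attn assumed absent,
# as above, so the pure-PyTorch attention path is used).
block_demo = TransformerBlock(seq_len=16, d_model=64, n_heads=4)
print(block_demo(torch.randn(2, 16, 64)).shape)  # torch.Size([2, 16, 64])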
591
+ class TransformerEncoder(nn.Module):
592
+ """
593
+ Assumes the input is already a sequence of embeddings (e.g. word or image
594
+ patch embeddings).
595
+ """
596
+
597
+ def __init__(
598
+ self,
599
+ seq_len,
600
+ d_model,
601
+ n_layers,
602
+ n_heads,
603
+ absolute_position_embedding=True,
604
+ relative_position_embedding=False,
605
+ source_size=None,
606
+ mlp_ratio=4,
607
+ activation: nn.Module = nn.ReLU,
608
+ activation_kwargs: Optional[dict] = None,
609
+ ff_linear_module_up=None,
610
+ ff_linear_module_down=None,
611
+ ff_dropout=0.0,
612
+ ff_inner_dropout=0.0,
613
+ ff_outer_dropout=0.0,
614
+ msa_dropout=0.0,
615
+ stochastic_depth=0.0,
616
+ causal=False,
617
+ linear_module=nn.Linear,
618
+ utility_tokens=0,
619
+ return_utility_tokens=False,
620
+ pre_norm=True,
621
+ post_norm=False,
622
+ normformer=False,
623
+ msa_scaling="d",
624
+ checkpoint_ff=True,
625
+ ):
626
+ """
627
+ Args:
628
+ msa_scaling: how the attention logits are scaled. Either "sqrtd", which
629
+ mimics the original Attention Is All You Need approach of dividing
630
+ by the square root of the per-head dimension, or "d", which scales
631
+ proportionally to 1 / head_dim per "Tensor Programs V...". Default "d".
632
+ """
633
+
634
+ if relative_position_embedding and (source_size is None):
635
+ raise ValueError(
636
+ "`source_size` for TransformerEncoder cannot be None if"
637
+ " `relative_position_embedding` is True"
638
+ )
639
+
640
+ if absolute_position_embedding and (seq_len is None):
641
+ raise ValueError(
642
+ "`seq_len` for TransformerEncoder cannot be None if"
643
+ " `absolute_position_embedding` is True"
644
+ )
645
+
646
+ super().__init__()
647
+ self.seq_len = seq_len
648
+ self.n_heads = n_heads
649
+ self._utility_tokens = utility_tokens
650
+ self.return_utility_tokens = return_utility_tokens
651
+
652
+ # Initialise utility tokens with a normal distribution, like standard PyTorch embeddings
653
+ if self._utility_tokens:
654
+ self._utility_token_embedding = nn.Parameter(
655
+ torch.empty(self._utility_tokens, d_model)
656
+ )
657
+ nn.init.normal_(self._utility_token_embedding, mean=0.0, std=1.0)
658
+ else:
659
+ self._utility_token_embedding = None
660
+
661
+ if self._utility_tokens and (self.seq_len is not None):
662
+ self.full_sequence_length = self.seq_len + self._utility_tokens
663
+ else:
664
+ self.full_sequence_length = self.seq_len
665
+
666
+ self.d_model = d_model
667
+
668
+ if absolute_position_embedding:
669
+ self.absolute_position_embedding = nn.Embedding(
670
+ self.full_sequence_length, d_model
671
+ )
672
+ else:
673
+ self.absolute_position_embedding = None
674
+
675
+ self.mlp_dropout = ff_dropout
676
+ self.msa_dropout = msa_dropout
677
+ self.stochastic_depth = stochastic_depth
678
+
679
+ assert isinstance(n_layers, int)
680
+
681
+ if n_layers == 1:
682
+ self.stochastic_depth_probabilities = [0.0]
683
+ else:
684
+ step_size = self.stochastic_depth / (n_layers - 1)
685
+ self.stochastic_depth_probabilities = [
686
+ i * step_size for i in range(n_layers)
687
+ ]
688
+
689
+ self.blocks = nn.ModuleList(
690
+ [
691
+ TransformerBlock(
692
+ self.full_sequence_length,
693
+ d_model,
694
+ n_heads,
695
+ relative_position_embedding=relative_position_embedding,
696
+ source_size=source_size,
697
+ utility_tokens=utility_tokens,
698
+ mlp_ratio=mlp_ratio,
699
+ activation=activation,
700
+ activation_kwargs=activation_kwargs,
701
+ ff_linear_module_up=ff_linear_module_up,
702
+ ff_linear_module_down=ff_linear_module_down,
703
+ msa_scaling=msa_scaling,
704
+ ff_dropout=ff_dropout,
705
+ ff_inner_dropout=ff_inner_dropout,
706
+ ff_outer_dropout=ff_outer_dropout,
707
+ msa_dropout=msa_dropout,
708
+ identity_probability=self.stochastic_depth_probabilities[i],
709
+ causal=causal,
710
+ linear_module=linear_module,
711
+ pre_norm=pre_norm,
712
+ post_norm=post_norm,
713
+ normformer=normformer,
714
+ checkpoint_ff=checkpoint_ff,
715
+ )
716
+ for i in range(n_layers)
717
+ ]
718
+ )
719
+
720
+ self.reset_parameters()
721
+
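# --- Editor's note: worked example of the stochastic-depth schedule above ---
# With n_layers=4 and stochastic_depth=0.3, step_size = 0.3 / (4 - 1) = 0.1 and
# the per-block drop probabilities ramp linearly from 0 to the maximum:
# [0.0, 0.1, 0.2, 0.3] (up to floating-point rounding), so later blocks are
# skipped more often during training.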
722
+ @property
723
+ def _kv_distances(self) -> str:
724
+ return ",".join([str(block._kv_distance) for block in self.blocks])
725
+
726
+ def preprocess(self, x):
727
+ if self._utility_tokens:
728
+ x = torch.cat(
729
+ [self._utility_token_embedding.expand(x.size(0), -1, -1), x], dim=1
730
+ )
731
+ else:
732
+ x = x
733
+
734
+ if self.absolute_position_embedding is not None:
735
+ x = x + self.absolute_position_embedding(
736
+ torch.arange(
737
+ 0, self.full_sequence_length, dtype=torch.long, device=x.device
738
+ ).unsqueeze(
739
+ 0
740
+ ) # to shape (1, seq_len) to broadcast over batch
741
+ )
742
+
743
+ return x
744
+
745
+ def forward(self, x):
746
+
747
+ x = self.preprocess(x)
748
+
749
+ for block in self.blocks:
750
+ x = block(x)
751
+
752
+ if self._utility_tokens and not self.return_utility_tokens:
753
+ return x[:, self._utility_tokens :, :]
754
+ else:
755
+ return x
756
+
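# --- Editor's note: illustrative shapes, assuming utility_tokens=2 and seq_len=16 ---
# preprocess prepends the 2 learned utility tokens, giving (batch, 18, d_model);
# every block then runs on the full 18-token sequence, and forward strips the
# utility tokens again (unless return_utility_tokens=True), returning
# (batch, 16, d_model).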
757
+ def attention_logits(self, x):
758
+
759
+ x = self.preprocess(x)
760
+
761
+ layer_scores = []
762
+
763
+ for block in self.blocks:
764
+ # Get pre-softmax attention logits, unsqueezed to (batch, 1, head, seq_len, seq_len)
765
+ layer_attention_logits = block.attention_logits(x).unsqueeze(1)
766
+ layer_scores.append(layer_attention_logits)
767
+ x = block(x)
768
+
769
+ return torch.cat(layer_scores, dim=1) # (batch, layer, head, seq_len, seq_len)
770
+
771
+ def reset_parameters(self):
772
+ if self._utility_token_embedding is not None:
773
+ nn.init.normal_(self._utility_token_embedding, mean=0.0, std=1.0)
774
+
775
+ if self.absolute_position_embedding is not None:
776
+ self.absolute_position_embedding.reset_parameters()
777
+
778
+ for block in self.blocks:
779
+ block.reset_parameters()
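# --- Editor's note: minimal end-to-end sketch, not part of the packaged file ---
# A small encoder over 16-token sequences of 32-dim embeddings with one utility
# token, assuming flash-attn is not installed (pure-PyTorch attention path).
encoder_demo = TransformerEncoder(
    seq_len=16,
    d_model=32,
    n_layers=2,
    n_heads=4,
    utility_tokens=1,
)
tokens_demo = torch.randn(2, 16, 32)
print(encoder_demo(tokens_demo).shape)                    # torch.Size([2, 16, 32])
print(encoder_demo.attention_logits(tokens_demo).shape)   # torch.Size([2, 2, 4, 17, 17])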