broccoli-ml 0.36.0__py3-none-any.whl → 10.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
broccoli/transformer.py CHANGED
@@ -1,16 +1,72 @@
+ import warnings
  import math
- from collections import OrderedDict
- from typing import Optional
- from numpy import random
+ from typing import Optional, Tuple

  import torch
  import torch.nn as nn
  import torch.nn.functional as F
+ from torch.utils.checkpoint import checkpoint

  from einops import rearrange

  from .rope import RotaryEmbedding, apply_rotary_emb
- from .linear import AnchoredLinear, SpectralNormLinear
+
+ try:
+     from flash_attn import flash_attn_func
+
+     print("Using flash-attn.")
+     FLASH_ATTN = True
+ except ImportError:
+     pass
+     FLASH_ATTN = False
+
+
+ class LayerScale(nn.Module):
+     def __init__(self, dim, init_values=1e-4):
+         super().__init__()
+         self.nondecay_scale = nn.Parameter(init_values * torch.ones(dim))
+
+     def forward(self, x):
+         return x * self.nondecay_scale
+
+
+ def drop_path(
+     x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True
+ ):
+     """
+     From https://github.com/huggingface/pytorch-image-models/blob/main/timm/layers/drop.py
+     Copyright 2019 Ross Wightman
+     See documentation and licence there.
+     """
+     if drop_prob == 0.0 or not training:
+         return x
+     keep_prob = 1 - drop_prob
+     shape = (x.shape[0],) + (1,) * (
+         x.ndim - 1
+     ) # work with diff dim tensors, not just 2D ConvNets
+     random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
+     if keep_prob > 0.0 and scale_by_keep:
+         random_tensor.div_(keep_prob)
+     return x * random_tensor
+
+
+ class DropPath(nn.Module):
+     """
+     From https://github.com/huggingface/pytorch-image-models/blob/main/timm/layers/drop.py
+     Copyright 2019 Ross Wightman
+     See documentation and licence there.
+     """
+
+     def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True):
+         super(DropPath, self).__init__()
+         self.drop_prob = drop_prob
+         self.scale_by_keep = scale_by_keep
+
+     def forward(self, x):
+         return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)
+
+     def extra_repr(self):
+         return f"drop_prob={round(self.drop_prob, 3):0.3f}"


  class MHAttention(nn.Module):
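
Note: the LayerScale and DropPath helpers added above are the usual building blocks for scaled, stochastically dropped residual branches (DropPath is the timm implementation of stochastic depth). A minimal self-contained sketch of how such a branch behaves — plain PyTorch with illustrative dimensions, not broccoli's own classes — looks like this:

    import torch
    import torch.nn as nn

    class ToyResidualBranch(nn.Module):
        # Sketch only: scales a sublayer's output by a small learned gain
        # (LayerScale) and randomly drops the whole branch per sample during
        # training (stochastic depth, as DropPath does).
        def __init__(self, dim, sublayer, init_values=1e-4, drop_prob=0.1):
            super().__init__()
            self.sublayer = sublayer
            self.scale = nn.Parameter(init_values * torch.ones(dim))  # LayerScale
            self.drop_prob = drop_prob

        def forward(self, x):
            branch = self.sublayer(x) * self.scale
            if self.training and self.drop_prob > 0.0:
                keep = 1.0 - self.drop_prob
                mask = branch.new_empty(branch.shape[0], 1, 1).bernoulli_(keep) / keep
                branch = branch * mask  # whole branch dropped for some samples
            return x + branch

    block = ToyResidualBranch(dim=16, sublayer=nn.Linear(16, 16))
    out = block(torch.randn(2, 5, 16))  # (batch, tokens, dim)
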
@@ -31,10 +87,19 @@ class MHAttention(nn.Module):
          causal=False,
          seq_len=None,
          linear_module: nn.Module = nn.Linear,
-         bos_tokens=0,
+         utility_tokens=0,
+         talking_heads=False,
          rotary_embedding=None,
          source_size=None,
+         scaling="d",
      ):
+         """
+         Args:
+             scaling: how should the attention logits be scaled? Can be "sqrtd"
+                 to mimic the original Attention is All You Need approach of
+                 dividing by the sqrt of the embedding Dimension or "d" per
+                 "Tensor Programs V...". Default "d"
+         """
          super().__init__()

          if rotary_embedding is not None:
@@ -42,12 +107,31 @@ class MHAttention(nn.Module):
          if causal:
              assert seq_len is not None

+         self.talking_heads = talking_heads
+
+         if self.talking_heads:
+             self.head_projection = nn.Linear(n_heads, n_heads, bias=False)
+             self.sample_projection = nn.Linear(n_heads, n_heads, bias=False)
+         else:
+             self.head_projection = None
+             self.sample_projection = None
+
          self.embed_dim = embed_dim
          self.n_heads = n_heads
          assert embed_dim % n_heads == 0
+         self.scaling = scaling

          self.head_dim = self.embed_dim // self.n_heads

+         if self.scaling == "sqrtd":
+             self.scaling_factor = 1 / math.sqrt(self.head_dim)
+         elif self.scaling == "d":
+             # 8/d_model for backwards compatibility,
+             # per https://github.com/microsoft/mup
+             self.scaling_factor = 8 / self.head_dim
+         else:
+             raise ValueError('`scaling` argument to MHAttention must be "d" or "sqrtd"')
+
          self.q_proj = linear_module(self.embed_dim, self.embed_dim, bias=False)
          self.k_proj = linear_module(self.embed_dim, self.embed_dim, bias=False)
          self.v_proj = linear_module(self.embed_dim, self.embed_dim, bias=False)
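
Note: the new `scaling` option controls the attention-logit scale factor computed above: "sqrtd" gives 1/sqrt(head_dim) as in "Attention Is All You Need", while "d" gives 8/head_dim following muP ("Tensor Programs V"). A quick arithmetic illustration with assumed sizes (embed_dim=512, n_heads=8 are examples, not package defaults):

    import math

    embed_dim, n_heads = 512, 8
    head_dim = embed_dim // n_heads          # 64, as in MHAttention.__init__

    sqrtd_factor = 1 / math.sqrt(head_dim)   # "sqrtd": 1/8 = 0.125
    d_factor = 8 / head_dim                  # "d" (muP-style): 8/64 = 0.125

    print(sqrtd_factor, d_factor)  # identical here because head_dim == 64
    # For other widths the two differ: head_dim=256 gives 0.0625 ("sqrtd")
    # versus 0.03125 ("d"), so the "d" rule shrinks logits faster as width grows.
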
@@ -66,7 +150,9 @@ class MHAttention(nn.Module):
          )
          self.rotary_embedding = rotary_embedding
          self.source_size = source_size
-         self.bos_tokens = bos_tokens
+         self.utility_tokens = utility_tokens
+
+         self.reset_parameters()

      @property
      def _kv_distance(self) -> float:
@@ -87,7 +173,71 @@ class MHAttention(nn.Module):

          return 1 - similarity

-     def forward(self, q, k, v):
+     def add_axial_rope(
+         self, q: torch.Tensor, k: torch.Tensor
+     ) -> Tuple[torch.Tensor, torch.Tensor]:
+         """
+         Apply Axial RoPE to all tokens except utility tokens
+         """
+
+         if len(self.source_size) == 1:
+             spatial_dimension_names = "D1"
+             spatial_dimension_values = {"D1": self.source_size[0]}
+         elif len(self.source_size) == 2:
+             spatial_dimension_names = "D1 D2"
+             spatial_dimension_values = {
+                 "D1": self.source_size[0],
+                 "D2": self.source_size[1],
+             }
+         elif len(self.source_size) == 3:
+             spatial_dimension_names = "D1 D2 D3"
+             spatial_dimension_values = {
+                 "D1": self.source_size[0],
+                 "D2": self.source_size[1],
+                 "D3": self.source_size[2],
+             }
+         else:
+             raise NotImplementedError(
+                 "`source_size` must be a tuple of 1, 2 or 3 integers"
+             )
+
+         q_util, q_img = q[:, : self.utility_tokens, :], q[:, self.utility_tokens :, :]
+         k_util, k_img = k[:, : self.utility_tokens, :], k[:, self.utility_tokens :, :]
+
+         q_img = rearrange(
+             q_img,
+             f"b ({spatial_dimension_names}) d -> b {spatial_dimension_names} d",
+             **spatial_dimension_values,
+         )
+         k_img = rearrange(
+             k_img,
+             f"b ({spatial_dimension_names}) d -> b {spatial_dimension_names} d",
+             **spatial_dimension_values,
+         )
+
+         freqs = self.rotary_embedding.get_axial_freqs(*self.source_size)
+
+         q_img = apply_rotary_emb(freqs, q_img)
+         k_img = apply_rotary_emb(freqs, k_img)
+
+         q_img = rearrange(
+             q_img,
+             f"b {spatial_dimension_names} d -> b ({spatial_dimension_names}) d",
+         )
+         k_img = rearrange(
+             k_img,
+             f"b {spatial_dimension_names} d -> b ({spatial_dimension_names}) d",
+         )
+
+         # Re-combine the utility tokens and the RoPE-enhanced sequence tokens
+         q = torch.cat([q_util, q_img], dim=1)
+         k = torch.cat([k_util, k_img], dim=1)
+
+         return q, k
+
+     def project_qkv(
+         self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor
+     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
          query_batch_size, query_tokens, query_features = q.size()
          key_batch_size, key_tokens, key_features = k.size()

@@ -100,66 +250,74 @@ class MHAttention(nn.Module):

          if self.causal:
              assert query_tokens == key_tokens
-             assert query_tokens == self.sequence_length
+             assert query_tokens == self.seq_len

-         # Project q, k and v
-         q = self.q_proj(q)
-         k = self.k_proj(k)
-         v = self.v_proj(v)
+         q, k, v = self.q_proj(q), self.k_proj(k), self.v_proj(v)

-         # Rearrange dimensions and add RoPE if needed
          if self.rotary_embedding is not None:
+             q, k = self.add_axial_rope(q, k)

-             if len(self.source_size) == 1:
-                 spatial_dimension_names = "D1"
-                 spatial_dimension_values = {"D1": self.source_size[0]}
-             elif len(self.source_size) == 2:
-                 spatial_dimension_names = "D1 D2"
-                 spatial_dimension_values = {
-                     "D1": self.source_size[0],
-                     "D2": self.source_size[1],
-                 }
-             elif len(self.source_size) == 3:
-                 spatial_dimension_names = "D1 D2 D3"
-                 spatial_dimension_values = {
-                     "D1": self.source_size[0],
-                     "D2": self.source_size[1],
-                     "D3": self.source_size[2],
-                 }
-             else:
-                 raise NotImplementedError(
-                     "`source_size` must be a tuple of 1, 2 or 3 integers"
-                 )
+         return q, k, v

-             q_bos, q_img = q[:, : self.bos_tokens, :], q[:, self.bos_tokens :, :]
-             k_bos, k_img = k[:, : self.bos_tokens, :], k[:, self.bos_tokens :, :]
+     def forward(self, q, k, v):

-             q_img = rearrange(
-                 q_img,
-                 f"b ({spatial_dimension_names}) d -> b {spatial_dimension_names} d",
-                 **spatial_dimension_values,
+         q, k, v = self.project_qkv(q, k, v)
+
+         if FLASH_ATTN and not self.talking_heads:
+             # Divide Q/K/V into heads
+             q = rearrange(q, "b t (h d) -> b t h d", h=self.n_heads)
+             k = rearrange(k, "b t (h d) -> b t h d", h=self.n_heads)
+             v = rearrange(v, "b t (h d) -> b t h d", h=self.n_heads)
+
+             output_with_heads = flash_attn_func(
+                 q,
+                 k,
+                 v,
+                 dropout_p=self.dropout.p if self.training else 0.0,
+                 softmax_scale=self.scaling_factor,
+                 causal=self.causal,
              )
-             k_img = rearrange(
-                 k_img,
-                 f"b ({spatial_dimension_names}) d -> b {spatial_dimension_names} d",
-                 **spatial_dimension_values,
-             )
-             freqs = self.rotary_embedding.get_axial_freqs(*self.source_size)
-             q_img = apply_rotary_emb(freqs, q_img)
-             k_img = apply_rotary_emb(freqs, k_img)

-             q_img = rearrange(
-                 q_img,
-                 f"b {spatial_dimension_names} d -> b ({spatial_dimension_names}) d",
-             )
-             k_img = rearrange(
-                 k_img,
-                 f"b {spatial_dimension_names} d -> b ({spatial_dimension_names}) d",
-             )
+             output_without_heads = rearrange(output_with_heads, "b t h d -> b t (h d)")

-             # Re-combine the BOS tokens and the RoPE-enhanced image tokens
-             q = torch.cat([q_bos, q_img], dim=1)
-             k = torch.cat([k_bos, k_img], dim=1)
+             return self.out_proj(output_without_heads)
+         else:
+             # Divide Q/K/V into heads
+             q = rearrange(q, "b t (h d) -> b h t d", h=self.n_heads)
+             k = rearrange(k, "b t (h d) -> b h t d", h=self.n_heads)
+             v = rearrange(v, "b t (h d) -> b h t d", h=self.n_heads)
+
+             qk_scores = q @ k.transpose(-1, -2)
+
+             qk_scores *= self.scaling_factor
+
+             if self.talking_heads:
+                 qk_scores = torch.einsum(
+                     "b h i j, o h -> b o i j", qk_scores, self.head_projection.weight
+                 )
+
+             # Apply mask if causal (must come before softmax)
+             if self.causal:
+                 qk_scores.masked_fill_(self.mask, float("-inf"))
+
+             qk_scores = F.softmax(qk_scores, dim=-1)
+
+             if self.talking_heads:
+                 qk_scores = torch.einsum(
+                     "b h i j, o h -> b o i j", qk_scores, self.sample_projection.weight
+                 )
+
+             qk_scores = self.dropout(qk_scores)
+
+             output_with_heads = qk_scores @ v
+
+             output_without_heads = rearrange(output_with_heads, "b h t d -> b t (h d)")
+
+             return self.out_proj(output_without_heads)
+
+     def attention_logits(self, q, k, v):
+
+         q, k, v = self.project_qkv(q, k, v)

          # Divide Q/K/V into heads
          q = rearrange(q, "b t (h d) -> b h t d", h=self.n_heads)
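
Note: project_qkv above applies add_axial_rope only to the spatial ("image") tokens and passes the leading utility tokens through unrotated. A stripped-down sketch of that split/recombine pattern, with a placeholder standing in for RotaryEmbedding/apply_rotary_emb and assumed shapes:

    import torch

    def fake_rope(t):           # placeholder for apply_rotary_emb(freqs, t)
        return t

    utility_tokens = 2
    q = torch.randn(4, utility_tokens + 8 * 8, 64)    # (batch, tokens, head_dim)

    q_util, q_img = q[:, :utility_tokens, :], q[:, utility_tokens:, :]
    q_img = q_img.reshape(4, 8, 8, 64)                # back to the 2D source_size grid
    q_img = fake_rope(q_img)                          # axial RoPE would rotate per axis
    q_img = q_img.reshape(4, 8 * 8, 64)
    q = torch.cat([q_util, q_img], dim=1)             # utility tokens stay position-free
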
@@ -168,19 +326,24 @@ class MHAttention(nn.Module):

          qk_scores = q @ k.transpose(-1, -2)

-         qk_scores /= math.sqrt(self.head_dim)
+         qk_scores *= self.scaling_factor

          # Apply mask if causal (must come before softmax)
          if self.causal:
              qk_scores.masked_fill_(self.mask, float("-inf"))

-         qk_scores = F.softmax(qk_scores, dim=-1)
-
-         output_with_heads = qk_scores @ v
+         return qk_scores # (batch, head, seq_len, seq_len)

-         output_without_heads = rearrange(output_with_heads, "b h t d -> b t (h d)")
-
-         return self.out_proj(output_without_heads)
+     def reset_parameters(self):
+         # Default nn.Linear init is kaiming_uniform, which is fine
+         self.q_proj.reset_parameters()
+         self.k_proj.reset_parameters()
+         self.v_proj.reset_parameters()
+         self.out_proj.reset_parameters()
+         if self.talking_heads:
+             # Initialize close to identity
+             nn.init.eye_(self.head_projection.weight)
+             nn.init.eye_(self.sample_projection.weight)


  class FeedforwardBlock(nn.Module):
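
Note: when `talking_heads` is enabled, the non-flash MHAttention path above mixes attention scores across the head axis before and after the softmax with learned (n_heads x n_heads) projections, and reset_parameters initialises those projections near the identity. A standalone sketch of just that mixing pattern, with toy shapes rather than broccoli's module:

    import torch
    import torch.nn.functional as F

    batch, heads, seq = 2, 4, 5
    scores = torch.randn(batch, heads, seq, seq)
    head_proj = torch.eye(heads)      # near-identity, as in reset_parameters above
    sample_proj = torch.eye(heads)

    scores = torch.einsum("bhij,oh->boij", scores, head_proj)   # pre-softmax mixing
    probs = F.softmax(scores, dim=-1)
    probs = torch.einsum("bhij,oh->boij", probs, sample_proj)   # post-softmax mixing
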
@@ -196,17 +359,29 @@ class FeedforwardBlock(nn.Module):
          activation=nn.ReLU,
          activation_kwargs=None,
          dropout=0.0,
+         inner_dropout=None,
+         outer_dropout=None,
          linear_module_up=nn.Linear,
          linear_module_down=nn.Linear,
          pre_norm=True,
          normformer=False,
          post_norm=True,
          residual_path=True,
+         checkpoint=True,
      ):
          super().__init__()

+         self.checkpoint = checkpoint
          self.residual_path = residual_path
          self.post_norm = post_norm
+         self.xglu = activation.__name__.endswith("GLU")
+
+         if self.residual_path and (output_features < input_features):
+             raise ValueError(
+                 "If the number of output features will be less than "
+                 "the number of input features, then `residual_path` "
+                 "should be set to False."
+             )

          if self.post_norm:
              self.layernorm = nn.LayerNorm(output_features)
@@ -216,32 +391,91 @@ class FeedforwardBlock(nn.Module):
          else:
              self.activation = activation()

-         self.dropout = nn.Dropout(dropout)
+         self.inner_dropout = nn.Dropout(
+             inner_dropout if inner_dropout is not None else dropout
+         )
+         self.outer_dropout = nn.Dropout(
+             outer_dropout if outer_dropout is not None else dropout
+         )

          self.max_features = (
-             2 * ratio * output_features
-             if activation.__name__.endswith("GLU")
-             else ratio * output_features
+             2 * int(ratio * output_features)
+             if self.xglu
+             else int(ratio * output_features)
+         )
+
+         self.linear_in = linear_module_up(input_features, self.max_features)
+         self.linear_out = linear_module_down(
+             int(ratio * output_features), output_features
          )

          self.process = nn.Sequential(
              *[
                  nn.LayerNorm(input_features) if pre_norm else nn.Identity(),
-                 linear_module_up(input_features, self.max_features),
+                 self.linear_in,
                  self.activation,
-                 nn.LayerNorm(ratio * output_features) if normformer else nn.Identity(),
-                 linear_module_down(ratio * output_features, output_features),
-                 self.dropout,
+                 self.inner_dropout,
+                 (
+                     nn.LayerNorm(int(ratio * output_features))
+                     if normformer
+                     else nn.Identity()
+                 ),
+                 self.linear_out,
+                 self.outer_dropout,
              ]
          )

+         self.recycling_enabled = False
+         if hasattr(self.linear_in, "row_recycling_rate") and hasattr(
+             self.linear_out, "column_recycling_rate"
+         ):
+             self.recycling_enabled = True
+             self.master_recycling_rate = self.linear_in.row_recycling_rate
+             self.linear_in.row_recycling_rate = 0.0
+             self.linear_out.column_recycling_rate = 0.0
+             if (
+                 hasattr(self.linear_in, "column_recycling_rate")
+                 and self.linear_in.column_recycling_rate > 0
+             ) or (
+                 hasattr(self.linear_out, "row_recycling_rate")
+                 and self.linear_out.row_recycling_rate > 0
+             ):
+                 raise NotImplementedError(
+                     "At the moment this layer can only support recycling linear "
+                     "layers if the in layer resets only rows and the out layer "
+                     "resets only columns."
+                 )
+
+         self.reset_parameters()
+
      def forward(self, x):
+
+         # Recycle weights if using recycling linear layers
+         if self.training and self.recycling_enabled:
+             indices = self.linear_out.get_reset_indices(1)
+             self.linear_in.reset_rows(indices, incoming_data=x)
+             self.linear_out.reset_columns(indices)
+
+         if self.checkpoint:
+             processed = checkpoint(self.process, x, use_reentrant=False)
+         else:
+             processed = self.process(x)
+
          if self.residual_path and self.post_norm:
-             return self.layernorm(x + self.process(x))
+             return self.layernorm(x + processed)
          elif self.residual_path:
-             return x + self.process(x)
+             return x + processed
          else:
-             return self.process(x)
+             return processed
+
+     def reset_parameters(self):
+         if self.post_norm:
+             self.layernorm.reset_parameters()
+
+         # Iterate over the sequential block to reset parameters
+         for module in self.process:
+             if hasattr(module, "reset_parameters"):
+                 module.reset_parameters()


  class TransformerBlock(nn.Module):
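
Note: two details of the reworked FeedforwardBlock above are easy to miss: activations whose class name ends in "GLU" halve their input, so the up-projection is built with twice the hidden width, and the whole sequential body can be gradient-checkpointed to trade compute for activation memory. A self-contained sketch under those assumptions (toy GLU and sizes of my choosing, not broccoli's activation classes):

    import torch
    import torch.nn as nn
    from torch.utils.checkpoint import checkpoint

    d_model, ratio = 64, 4
    hidden = int(ratio * d_model)          # 256
    glu_up_width = 2 * hidden              # 512: the GLU splits this back to 256

    class ToyGLU(nn.Module):
        def forward(self, x):
            a, b = x.chunk(2, dim=-1)
            return a * torch.sigmoid(b)

    process = nn.Sequential(
        nn.Linear(d_model, glu_up_width),
        ToyGLU(),
        nn.Linear(hidden, d_model),
    )
    x = torch.randn(2, 10, d_model, requires_grad=True)
    y = checkpoint(process, x, use_reentrant=False)  # the checkpoint=True path above
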
@@ -258,13 +492,19 @@ class TransformerBlock(nn.Module):
          seq_len,
          d_model,
          n_heads,
-         position_embedding_type="absolute", # absolute or relative
+         relative_position_embedding=False,
          source_size=None,
-         bos_tokens=0,
+         utility_tokens=0,
+         talking_heads=False,
          mlp_ratio=4,
          activation: nn.Module = nn.ReLU,
          activation_kwargs: Optional[dict] = None,
-         mlp_dropout=0.0,
+         ff_linear_module_up=None,
+         ff_linear_module_down=None,
+         msa_scaling="d",
+         ff_dropout=0.0,
+         ff_inner_dropout=0.0,
+         ff_outer_dropout=0.0,
          msa_dropout=0.0,
          identity_probability=0.0,
          causal=False,
@@ -272,19 +512,37 @@ class TransformerBlock(nn.Module):
          pre_norm=True,
          post_norm=False,
          normformer=False,
+         checkpoint_ff=True,
+         layerscale=True,
      ):
+         """
+         Args:
+             msa_scaling: how should the attention logits be scaled? Can be "sqrtd"
+                 to mimic the original Attention is All You Need approach of
+                 dividing by the sqrt of the embedding Dimension or "d" per
+                 "Tensor Programs V...". Default "d"
+         """
+
          super().__init__()

          self.pre_norm = pre_norm
          self.post_norm = post_norm
          self.normformer = normformer

-         self.identity_probability = identity_probability
+         self.drop_path = DropPath(drop_prob=identity_probability, scale_by_keep=True)

          self.layer_norm_1 = nn.LayerNorm(d_model)
          self.layer_norm_2 = nn.LayerNorm(d_model)
+         self.layer_norm_3 = nn.LayerNorm(d_model)

-         if position_embedding_type == "relative":
+         if layerscale:
+             self.layerscale1 = LayerScale(d_model)
+             self.layerscale2 = LayerScale(d_model)
+         else:
+             self.layerscale1 = nn.Identity()
+             self.layerscale2 = nn.Identity()
+
+         if relative_position_embedding:
              max_freq = int(max(source_size) / 2) # Suggested by Gemini!
              if d_model < 16:
                  dim = d_model
@@ -305,7 +563,9 @@ class TransformerBlock(nn.Module):
              linear_module=linear_module,
              rotary_embedding=self.rotary_embedding,
              source_size=source_size,
-             bos_tokens=bos_tokens,
+             utility_tokens=utility_tokens,
+             talking_heads=talking_heads,
+             scaling=msa_scaling,
          )

          # Submodule for the feedforward process
@@ -315,56 +575,75 @@ class TransformerBlock(nn.Module):
              d_model,
              activation=activation,
              activation_kwargs=activation_kwargs,
-             dropout=mlp_dropout,
-             linear_module_up=linear_module,
-             linear_module_down=linear_module,
-             pre_norm=pre_norm,
+             dropout=ff_dropout,
+             inner_dropout=ff_inner_dropout,
+             outer_dropout=ff_outer_dropout,
+             linear_module_up=(
+                 ff_linear_module_up
+                 if ff_linear_module_up is not None
+                 else linear_module
+             ),
+             linear_module_down=(
+                 ff_linear_module_down
+                 if ff_linear_module_down is not None
+                 else linear_module
+             ),
+             pre_norm=False, # Handled outside the block
              normformer=normformer,
-             post_norm=post_norm,
-             residual_path=True,
+             post_norm=False, # Handled outside the block
+             residual_path=False, # Handled outside the block
+             checkpoint=checkpoint_ff,
          )

+         self.reset_parameters()
+
      @property
      def _kv_distance(self) -> float:
          return self.attn._kv_distance

      def forward(self, x):
-         if not self.training:
-             identity_probability = 0.0
-         else:
-             identity_probability = self.identity_probability
-
-         # perform the identity operation for some rows in the batch
-         dist = torch.distributions.Binomial(x.size(0), identity_probability)
-         identity_count = int(dist.sample().item())
-
-         shuffle_indices = torch.randperm(x.size(0), device=x.device)
-         unshuffle_indices = torch.argsort(shuffle_indices)
-         shuffled = x[shuffle_indices, :, :]
-         identity_x = shuffled[:identity_count, :, :]
-         process_x = shuffled[identity_count:, :, :]
-
-         residual_x = process_x

          if self.pre_norm:
-             process_x = self.layer_norm_1(process_x)
+             x = self.layer_norm_1(x)
+             x = x + self.drop_path(self.layerscale1(self.attn(x, x, x)))
+             x = self.layer_norm_2(x)
+             x = x + self.drop_path(self.layerscale2(self.ff(x)))
+             if self.post_norm: # i.e. in addition! Pre and post.
+                 x = self.layer_norm_3(x)
+         elif self.post_norm: # i.e. only, not prenorm, just post
+             x = x + self.drop_path(self.layerscale1(self.attn(x, x, x)))
+             x = self.layer_norm_1(x)
+             x = x + self.drop_path(self.layerscale2(self.ff(x)))
+             x = self.layer_norm_2(x)
+         else: # Not pre or post norm. Stand well back.
+             x = x + self.drop_path(self.layerscale1(self.attn(x, x, x)))
+             x = x + self.drop_path(self.layerscale2(self.ff(x)))

-         process_x = residual_x + self.attn(process_x, process_x, process_x)
-
-         if self.post_norm:
-             process_x = self.layer_norm_2(process_x)
+         return x

-         process_x = self.ff(process_x)
+     def attention_logits(self, x):
+         """
+         Give back the attention scores used in this layer.
+         """
+         if self.pre_norm:
+             x = self.layer_norm_1(x)
+             return self.attn.attention_logits(x, x, x)
+         else:
+             return self.attn.attention_logits(x, x, x)

-         x = torch.cat([identity_x, process_x])[unshuffle_indices, :, :].contiguous()
+     def reset_parameters(self):
+         self.layer_norm_1.reset_parameters()
+         self.layer_norm_2.reset_parameters()
+         self.layer_norm_3.reset_parameters()

-         return x
+         self.attn.reset_parameters()
+         self.ff.reset_parameters()


  class TransformerEncoder(nn.Module):
      """
      This assumes we already get a sequence of embeddings (e.g. word or image
-     patch embeddings). It uses learned positional embeddings.
+     patch embeddings).
      """

      def __init__(
@@ -373,54 +652,93 @@ class TransformerEncoder(nn.Module):
          d_model,
          n_layers,
          n_heads,
-         position_embedding_type="absolute", # absolute or relative
+         absolute_position_embedding=True,
+         relative_position_embedding=False,
          source_size=None,
          mlp_ratio=4,
          activation: nn.Module = nn.ReLU,
          activation_kwargs: Optional[dict] = None,
-         mlp_dropout=0.0,
+         ff_linear_module_up=None,
+         ff_linear_module_down=None,
+         ff_dropout=0.0,
+         ff_inner_dropout=0.0,
+         ff_outer_dropout=0.0,
          msa_dropout=0.0,
          stochastic_depth=0.0,
          causal=False,
          linear_module=nn.Linear,
-         bos_tokens=0,
-         return_bos_tokens=False,
+         utility_tokens=0,
+         talking_heads=False,
+         return_utility_tokens=False,
          pre_norm=True,
          post_norm=False,
          normformer=False,
+         msa_scaling="d",
+         checkpoint_ff=True,
+         layerscale=True,
      ):
-         if position_embedding_type == "relative":
-             assert source_size is not None # TODO: make this a proper exception
+         """
+         Args:
+             msa_scaling: how should the attention logits be scaled? Can be "sqrtd"
+                 to mimic the original Attention is All You Need approach of
+                 dividing by the sqrt of the embedding Dimension or "d" per
+                 "Tensor Programs V...". Default "d"
+         """
+
+         if relative_position_embedding and (source_size is None):
+             raise ValueError(
+                 "`source_size` for TransformerEncoder cannot be None if"
+                 " `relative_position_embedding` is True"
+             )
+
+         if absolute_position_embedding and (seq_len is None):
+             raise ValueError(
+                 "`seq_len` for TransformerEncoder cannot be None if"
+                 " `absolute_position_embedding` is True"
+             )

          super().__init__()
+
+         if FLASH_ATTN and talking_heads:
+             warnings.warn(
+                 "Using talking heads currently prevents using flash attention.",
+                 stacklevel=2,
+             )
+
          self.seq_len = seq_len
          self.n_heads = n_heads
-         self._bos_tokens = bos_tokens
-         self.return_bos_tokens = return_bos_tokens
-
-         # Initialise BOS tokens with normal init, like usual Pytorch embeddings
-         if self._bos_tokens:
-             self._bos_embedding = nn.Parameter(torch.empty(self._bos_tokens, d_model))
-             nn.init.normal_(self._bos_embedding, mean=0.0, std=1.0)
-             self.full_sequence_length = self.seq_len + self._bos_tokens
+         self._utility_tokens = utility_tokens
+         self.return_utility_tokens = return_utility_tokens
+
+         # Initialise utility tokens with normal init, like usual Pytorch embeddings
+         if self._utility_tokens:
+             self._utility_token_embedding = nn.Parameter(
+                 torch.empty(self._utility_tokens, d_model)
+             )
+             nn.init.normal_(self._utility_token_embedding, mean=0.0, std=1.0)
+         else:
+             self._utility_token_embedding = None
+
+         if self._utility_tokens and (self.seq_len is not None):
+             self.full_sequence_length = self.seq_len + self._utility_tokens
          else:
-             self._bos_embedding = None
              self.full_sequence_length = self.seq_len

          self.d_model = d_model

-         self.position_embedding_type = position_embedding_type
-
-         if self.position_embedding_type == "absolute":
+         if absolute_position_embedding:
              self.absolute_position_embedding = nn.Embedding(
                  self.full_sequence_length, d_model
              )
+         else:
+             self.absolute_position_embedding = None

-         self.mlp_dropout = mlp_dropout
+         self.mlp_dropout = ff_dropout
          self.msa_dropout = msa_dropout
          self.stochastic_depth = stochastic_depth

-         assert isinstance(n_layers, int) # XXX: make this a proper Exception
+         assert isinstance(n_layers, int)
+
          if n_layers == 1:
              self.stochastic_depth_probabilities = [0.0]
          else:
@@ -435,13 +753,19 @@ class TransformerEncoder(nn.Module):
                      self.full_sequence_length,
                      d_model,
                      n_heads,
-                     position_embedding_type=position_embedding_type,
+                     relative_position_embedding=relative_position_embedding,
                      source_size=source_size,
-                     bos_tokens=bos_tokens,
+                     utility_tokens=utility_tokens,
+                     talking_heads=talking_heads,
                      mlp_ratio=mlp_ratio,
                      activation=activation,
                      activation_kwargs=activation_kwargs,
-                     mlp_dropout=mlp_dropout,
+                     ff_linear_module_up=ff_linear_module_up,
+                     ff_linear_module_down=ff_linear_module_down,
+                     msa_scaling=msa_scaling,
+                     ff_dropout=ff_dropout,
+                     ff_inner_dropout=ff_inner_dropout,
+                     ff_outer_dropout=ff_outer_dropout,
                      msa_dropout=msa_dropout,
                      identity_probability=self.stochastic_depth_probabilities[i],
                      causal=causal,
@@ -449,22 +773,28 @@ class TransformerEncoder(nn.Module):
                      pre_norm=pre_norm,
                      post_norm=post_norm,
                      normformer=normformer,
+                     checkpoint_ff=checkpoint_ff,
+                     layerscale=layerscale,
                  )
                  for i in range(n_layers)
              ]
          )

+         self.reset_parameters()
+
      @property
      def _kv_distances(self) -> float:
          return ",".join([str(block._kv_distance) for block in self.blocks])

-     def forward(self, x):
-         if self._bos_tokens:
-             x = torch.cat([self._bos_embedding.expand(x.size(0), -1, -1), x], dim=1)
+     def preprocess(self, x):
+         if self._utility_tokens:
+             x = torch.cat(
+                 [self._utility_token_embedding.expand(x.size(0), -1, -1), x], dim=1
+             )
          else:
              x = x

-         if self.position_embedding_type == "absolute":
+         if self.absolute_position_embedding is not None:
              x = x + self.absolute_position_embedding(
                  torch.arange(
                      0, self.full_sequence_length, dtype=torch.long, device=x.device
@@ -473,10 +803,40 @@ class TransformerEncoder(nn.Module):
                  ) # to shape (1, seq_len) to broadcast over batch
              )

+         return x
+
+     def forward(self, x):
+
+         x = self.preprocess(x)
+
          for block in self.blocks:
              x = block(x)

-         if self._bos_tokens and not self.return_bos_tokens:
-             return x[:, self._bos_tokens :, :]
+         if self._utility_tokens and not self.return_utility_tokens:
+             return x[:, self._utility_tokens :, :]
          else:
              return x
+
+     def attention_logits(self, x):
+
+         x = self.preprocess(x)
+
+         layer_scores = []
+
+         for block in self.blocks:
+             # Get attention scores with shape (batch, 1, head, seq_len, seq_len)
+             layer_attention_logits = block.attention_logits(x).unsqueeze(1)
+             layer_scores.append(layer_attention_logits)
+             x = block(x)
+
+         return torch.cat(layer_scores, dim=1) # (batch, layer, head, seq_len, seq_len)
+
+     def reset_parameters(self):
+         if self._utility_token_embedding is not None:
+             nn.init.normal_(self._utility_token_embedding, mean=0.0, std=1.0)
+
+         if self.absolute_position_embedding is not None:
+             self.absolute_position_embedding.reset_parameters()
+
+         for block in self.blocks:
+             block.reset_parameters()
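
Note: a hypothetical end-to-end usage sketch of the renamed API, based only on the argument names visible in this diff (the module path comes from the file shown above; sizes and flag choices are assumptions, not package defaults):

    import torch
    from broccoli.transformer import TransformerEncoder

    encoder = TransformerEncoder(
        seq_len=64,            # e.g. an 8x8 patch grid
        d_model=128,
        n_layers=4,
        n_heads=8,
        source_size=(8, 8),
        relative_position_embedding=True,
        utility_tokens=1,      # replaces the old bos_tokens argument
        talking_heads=False,   # leaving this False keeps the flash-attn path usable
    )

    x = torch.randn(2, 64, 128)            # (batch, seq_len, d_model)
    y = encoder(x)                         # utility tokens stripped unless return_utility_tokens=True
    logits = encoder.attention_logits(x)   # per-layer attention logits stacked along dim 1
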