broccoli-ml 0.29.1-py3-none-any.whl → 10.0.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
broccoli/transformer.py CHANGED
@@ -1,16 +1,72 @@
1
+ import warnings
1
2
  import math
2
- from collections import OrderedDict
3
- from typing import Optional
4
- from numpy import random
3
+ from typing import Optional, Tuple
5
4
 
6
5
  import torch
7
6
  import torch.nn as nn
8
7
  import torch.nn.functional as F
8
+ from torch.utils.checkpoint import checkpoint
9
9
 
10
10
  from einops import rearrange
11
11
 
12
12
  from .rope import RotaryEmbedding, apply_rotary_emb
13
- from .linear import AnchoredLinear, SpectralNormLinear
13
+
14
+ try:
15
+ from flash_attn import flash_attn_func
16
+
17
+ print("Using flash-attn.")
18
+ FLASH_ATTN = True
19
+ except ImportError:
20
+ pass
21
+ FLASH_ATTN = False
22
+
23
+
24
+ class LayerScale(nn.Module):
25
+ def __init__(self, dim, init_values=1e-4):
26
+ super().__init__()
27
+ self.nondecay_scale = nn.Parameter(init_values * torch.ones(dim))
28
+
29
+ def forward(self, x):
30
+ return x * self.nondecay_scale
31
+
32
+
33
+ def drop_path(
34
+ x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True
35
+ ):
36
+ """
37
+ From https://github.com/huggingface/pytorch-image-models/blob/main/timm/layers/drop.py
38
+ Copyright 2019 Ross Wightman
39
+ See documentation and licence there.
40
+ """
41
+ if drop_prob == 0.0 or not training:
42
+ return x
43
+ keep_prob = 1 - drop_prob
44
+ shape = (x.shape[0],) + (1,) * (
45
+ x.ndim - 1
46
+ ) # work with diff dim tensors, not just 2D ConvNets
47
+ random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
48
+ if keep_prob > 0.0 and scale_by_keep:
49
+ random_tensor.div_(keep_prob)
50
+ return x * random_tensor
51
+
52
+
53
+ class DropPath(nn.Module):
54
+ """
55
+ From https://github.com/huggingface/pytorch-image-models/blob/main/timm/layers/drop.py
56
+ Copyright 2019 Ross Wightman
57
+ See documentation and licence there.
58
+ """
59
+
60
+ def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True):
61
+ super(DropPath, self).__init__()
62
+ self.drop_prob = drop_prob
63
+ self.scale_by_keep = scale_by_keep
64
+
65
+ def forward(self, x):
66
+ return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)
67
+
68
+ def extra_repr(self):
69
+ return f"drop_prob={round(self.drop_prob, 3):0.3f}"
14
70
 
15
71
 
16
72
  class MHAttention(nn.Module):
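The `LayerScale` and `DropPath` helpers added above are the usual residual-branch regularisers: a learnable per-channel gain initialised to 1e-4, and stochastic depth that zeroes the residual branch for a random subset of samples during training (rescaling the survivors). A minimal sketch of how the two compose around a residual connection; `branch` and `drop_path_demo` are illustrative stand-ins, not names from this package:

```python
import torch
import torch.nn as nn

# Standalone demo of the pattern used later in TransformerBlock:
# y = x + DropPath(LayerScale(branch(x)))
dim = 8
branch = nn.Linear(dim, dim)                        # stand-in for attention or the MLP
layerscale = nn.Parameter(1e-4 * torch.ones(dim))   # LayerScale: tiny per-channel gain

def drop_path_demo(x, drop_prob=0.1, training=True):
    # Zero the residual branch for a random subset of samples, rescaling the rest
    if drop_prob == 0.0 or not training:
        return x
    keep_prob = 1 - drop_prob
    mask = x.new_empty((x.shape[0],) + (1,) * (x.ndim - 1)).bernoulli_(keep_prob)
    return x * mask / keep_prob

x = torch.randn(4, 16, dim)                          # (batch, tokens, features)
y = x + drop_path_demo(layerscale * branch(x))
print(y.shape)  # torch.Size([4, 16, 8])
```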
@@ -21,45 +77,6 @@ class MHAttention(nn.Module):
21
77
  are the same shape.
22
78
 
23
79
  Assumes bias=False and batch_first=True, as God intended.
24
-
25
- Optionally adds various bells and whistles suggested in the
26
- literature, including:
27
-
28
- Noam Shazeer's scaled attention per "Attention is All You Need"
29
- (https://arxiv.org/abs/1706.03762).
30
-
31
- Max subtract softmax as discussed in "Attention As An RNN"
32
- (https://arxiv.org/abs/2405.13956)
33
-
34
- Log-length scaled softmax per "Overcoming a Theoretical Limitation of
35
- Self-Attention" (https://arxiv.org/abs/2202.12172).
36
-
37
- Quiet softmax per
38
- https://www.evanmiller.org/attention-is-off-by-one.html
39
-
40
- Args:
41
- d_model: ...
42
- n_heads: ...
43
- dropout: ...
44
- causal: should a causal mask be applied to the logits before attention
45
- is applied? This is standard when using self-attention. Cannot be
46
- True if inputs won't be square (e.g. if sequence length for
47
- encoder and decoder are different)
48
- sequence_length: ...
49
- share_kv: ...
50
- linear_module: ...
51
- max_subtract: if True, the maximum logit value is subtracted from all
52
- logits before performing the softmax operation to create a more
53
- numerically stable softmax. This is discussed in "Attention As An
54
- RNN" (https://arxiv.org/abs/2405.13956).
55
- d_model_scale: ...
56
- log_length_scale: if True, multiplies logits by the log length of
57
- the decoder sequence before performing the softmax operation, as
58
- proposed in "Overcoming a Theoretical Limitation of Self-Attention"
59
- (https://arxiv.org/abs/2202.12172).
60
- quiet: if True, adds 1 to the denominator of the softmax operation,
61
- allowing some tokens to attend to no other tokens as described in
62
- https://www.evanmiller.org/attention-is-off-by-one.html.
63
80
  """
64
81
 
65
82
  def __init__(
@@ -70,10 +87,19 @@ class MHAttention(nn.Module):
70
87
  causal=False,
71
88
  seq_len=None,
72
89
  linear_module: nn.Module = nn.Linear,
73
- bos_tokens=0,
90
+ utility_tokens=0,
91
+ talking_heads=False,
74
92
  rotary_embedding=None,
75
93
  source_size=None,
94
+ scaling="d",
76
95
  ):
96
+ """
97
+ Args:
98
+ scaling: how should the attention logits be scaled? Can be "sqrtd"
99
+ to mimic the original "Attention Is All You Need" approach of
100
+ dividing by the square root of the head dimension, or "d" per
101
+ "Tensor Programs V...". Default: "d".
102
+ """
77
103
  super().__init__()
78
104
 
79
105
  if rotary_embedding is not None:
@@ -81,12 +107,31 @@ class MHAttention(nn.Module):
81
107
  if causal:
82
108
  assert seq_len is not None
83
109
 
110
+ self.talking_heads = talking_heads
111
+
112
+ if self.talking_heads:
113
+ self.head_projection = nn.Linear(n_heads, n_heads, bias=False)
114
+ self.sample_projection = nn.Linear(n_heads, n_heads, bias=False)
115
+ else:
116
+ self.head_projection = None
117
+ self.sample_projection = None
118
+
84
119
  self.embed_dim = embed_dim
85
120
  self.n_heads = n_heads
86
121
  assert embed_dim % n_heads == 0
122
+ self.scaling = scaling
87
123
 
88
124
  self.head_dim = self.embed_dim // self.n_heads
89
125
 
126
+ if self.scaling == "sqrtd":
127
+ self.scaling_factor = 1 / math.sqrt(self.head_dim)
128
+ elif self.scaling == "d":
129
+ # 8/head_dim for backwards compatibility,
130
+ # per https://github.com/microsoft/mup
131
+ self.scaling_factor = 8 / self.head_dim
132
+ else:
133
+ raise ValueError('`scaling` argument to MHAttention must be "d" or "sqrtd"')
134
+
90
135
  self.q_proj = linear_module(self.embed_dim, self.embed_dim, bias=False)
91
136
  self.k_proj = linear_module(self.embed_dim, self.embed_dim, bias=False)
92
137
  self.v_proj = linear_module(self.embed_dim, self.embed_dim, bias=False)
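The `scaling` option above chooses between the classic 1/sqrt(d_k) factor ("sqrtd") and a muP-style 1/d factor ("d"), implemented as 8/head_dim so that the two coincide at a head dimension of 64. A quick numeric check, assuming example head dimensions:

```python
import math

for head_dim in (64, 256):              # assumed example head dimensions
    sqrtd = 1 / math.sqrt(head_dim)     # "sqrtd": 1/sqrt(d_k), as in Attention Is All You Need
    mup = 8 / head_dim                  # "d": muP-style 1/d, with 8 making both match at d=64
    print(head_dim, round(sqrtd, 4), round(mup, 4))
# 64 0.125 0.125    <- identical at head_dim 64
# 256 0.0625 0.0312 <- the "d" variant shrinks faster as heads widen
```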
@@ -105,7 +150,9 @@ class MHAttention(nn.Module):
105
150
  )
106
151
  self.rotary_embedding = rotary_embedding
107
152
  self.source_size = source_size
108
- self.bos_tokens = bos_tokens
153
+ self.utility_tokens = utility_tokens
154
+
155
+ self.reset_parameters()
109
156
 
110
157
  @property
111
158
  def _kv_distance(self) -> float:
@@ -126,7 +173,71 @@ class MHAttention(nn.Module):
126
173
 
127
174
  return 1 - similarity
128
175
 
129
- def forward(self, q, k, v):
176
+ def add_axial_rope(
177
+ self, q: torch.Tensor, k: torch.Tensor
178
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
179
+ """
180
+ Apply Axial RoPE to all tokens except utility tokens
181
+ """
182
+
183
+ if len(self.source_size) == 1:
184
+ spatial_dimension_names = "D1"
185
+ spatial_dimension_values = {"D1": self.source_size[0]}
186
+ elif len(self.source_size) == 2:
187
+ spatial_dimension_names = "D1 D2"
188
+ spatial_dimension_values = {
189
+ "D1": self.source_size[0],
190
+ "D2": self.source_size[1],
191
+ }
192
+ elif len(self.source_size) == 3:
193
+ spatial_dimension_names = "D1 D2 D3"
194
+ spatial_dimension_values = {
195
+ "D1": self.source_size[0],
196
+ "D2": self.source_size[1],
197
+ "D3": self.source_size[2],
198
+ }
199
+ else:
200
+ raise NotImplementedError(
201
+ "`source_size` must be a tuple of 1, 2 or 3 integers"
202
+ )
203
+
204
+ q_util, q_img = q[:, : self.utility_tokens, :], q[:, self.utility_tokens :, :]
205
+ k_util, k_img = k[:, : self.utility_tokens, :], k[:, self.utility_tokens :, :]
206
+
207
+ q_img = rearrange(
208
+ q_img,
209
+ f"b ({spatial_dimension_names}) d -> b {spatial_dimension_names} d",
210
+ **spatial_dimension_values,
211
+ )
212
+ k_img = rearrange(
213
+ k_img,
214
+ f"b ({spatial_dimension_names}) d -> b {spatial_dimension_names} d",
215
+ **spatial_dimension_values,
216
+ )
217
+
218
+ freqs = self.rotary_embedding.get_axial_freqs(*self.source_size)
219
+
220
+ q_img = apply_rotary_emb(freqs, q_img)
221
+ k_img = apply_rotary_emb(freqs, k_img)
222
+
223
+ q_img = rearrange(
224
+ q_img,
225
+ f"b {spatial_dimension_names} d -> b ({spatial_dimension_names}) d",
226
+ )
227
+ k_img = rearrange(
228
+ k_img,
229
+ f"b {spatial_dimension_names} d -> b ({spatial_dimension_names}) d",
230
+ )
231
+
232
+ # Re-combine the utility tokens and the RoPE-enhanced sequence tokens
233
+ q = torch.cat([q_util, q_img], dim=1)
234
+ k = torch.cat([k_util, k_img], dim=1)
235
+
236
+ return q, k
237
+
238
+ def project_qkv(
239
+ self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor
240
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
130
241
  query_batch_size, query_tokens, query_features = q.size()
131
242
  key_batch_size, key_tokens, key_features = k.size()
132
243
 
@@ -139,66 +250,74 @@ class MHAttention(nn.Module):
139
250
 
140
251
  if self.causal:
141
252
  assert query_tokens == key_tokens
142
- assert query_tokens == self.sequence_length
253
+ assert query_tokens == self.seq_len
143
254
 
144
- # Project q, k and v
145
- q = self.q_proj(q)
146
- k = self.k_proj(k)
147
- v = self.v_proj(v)
255
+ q, k, v = self.q_proj(q), self.k_proj(k), self.v_proj(v)
148
256
 
149
- # Rearrange dimensions and add RoPE if needed
150
257
  if self.rotary_embedding is not None:
258
+ q, k = self.add_axial_rope(q, k)
151
259
 
152
- if len(self.source_size) == 1:
153
- spatial_dimension_names = "D1"
154
- spatial_dimension_values = {"D1": self.source_size[0]}
155
- elif len(self.source_size) == 2:
156
- spatial_dimension_names = "D1 D2"
157
- spatial_dimension_values = {
158
- "D1": self.source_size[0],
159
- "D2": self.source_size[1],
160
- }
161
- elif len(self.source_size) == 3:
162
- spatial_dimension_names = "D1 D2 D3"
163
- spatial_dimension_values = {
164
- "D1": self.source_size[0],
165
- "D2": self.source_size[1],
166
- "D3": self.source_size[2],
167
- }
168
- else:
169
- raise NotImplementedError(
170
- "`source_size` must be a tuple of 1, 2 or 3 integers"
171
- )
260
+ return q, k, v
172
261
 
173
- q_bos, q_img = q[:, : self.bos_tokens, :], q[:, self.bos_tokens :, :]
174
- k_bos, k_img = k[:, : self.bos_tokens, :], k[:, self.bos_tokens :, :]
262
+ def forward(self, q, k, v):
175
263
 
176
- q_img = rearrange(
177
- q_img,
178
- f"b ({spatial_dimension_names}) d -> b {spatial_dimension_names} d",
179
- **spatial_dimension_values,
180
- )
181
- k_img = rearrange(
182
- k_img,
183
- f"b ({spatial_dimension_names}) d -> b {spatial_dimension_names} d",
184
- **spatial_dimension_values,
264
+ q, k, v = self.project_qkv(q, k, v)
265
+
266
+ if FLASH_ATTN and not self.talking_heads:
267
+ # Divide Q/K/V into heads
268
+ q = rearrange(q, "b t (h d) -> b t h d", h=self.n_heads)
269
+ k = rearrange(k, "b t (h d) -> b t h d", h=self.n_heads)
270
+ v = rearrange(v, "b t (h d) -> b t h d", h=self.n_heads)
271
+
272
+ output_with_heads = flash_attn_func(
273
+ q,
274
+ k,
275
+ v,
276
+ dropout_p=self.dropout.p if self.training else 0.0,
277
+ softmax_scale=self.scaling_factor,
278
+ causal=self.causal,
185
279
  )
186
- freqs = self.rotary_embedding.get_axial_freqs(*self.source_size)
187
- q_img = apply_rotary_emb(freqs, q_img)
188
- k_img = apply_rotary_emb(freqs, k_img)
189
280
 
190
- q_img = rearrange(
191
- q_img,
192
- f"b {spatial_dimension_names} d -> b ({spatial_dimension_names}) d",
193
- )
194
- k_img = rearrange(
195
- k_img,
196
- f"b {spatial_dimension_names} d -> b ({spatial_dimension_names}) d",
197
- )
281
+ output_without_heads = rearrange(output_with_heads, "b t h d -> b t (h d)")
282
+
283
+ return self.out_proj(output_without_heads)
284
+ else:
285
+ # Divide Q/K/V into heads
286
+ q = rearrange(q, "b t (h d) -> b h t d", h=self.n_heads)
287
+ k = rearrange(k, "b t (h d) -> b h t d", h=self.n_heads)
288
+ v = rearrange(v, "b t (h d) -> b h t d", h=self.n_heads)
198
289
 
199
- # Re-combine the BOS tokens and the RoPE-enhanced image tokens
200
- q = torch.cat([q_bos, q_img], dim=1)
201
- k = torch.cat([k_bos, k_img], dim=1)
290
+ qk_scores = q @ k.transpose(-1, -2)
291
+
292
+ qk_scores *= self.scaling_factor
293
+
294
+ if self.talking_heads:
295
+ qk_scores = torch.einsum(
296
+ "b h i j, o h -> b o i j", qk_scores, self.head_projection.weight
297
+ )
298
+
299
+ # Apply mask if causal (must come before softmax)
300
+ if self.causal:
301
+ qk_scores.masked_fill_(self.mask, float("-inf"))
302
+
303
+ qk_scores = F.softmax(qk_scores, dim=-1)
304
+
305
+ if self.talking_heads:
306
+ qk_scores = torch.einsum(
307
+ "b h i j, o h -> b o i j", qk_scores, self.sample_projection.weight
308
+ )
309
+
310
+ qk_scores = self.dropout(qk_scores)
311
+
312
+ output_with_heads = qk_scores @ v
313
+
314
+ output_without_heads = rearrange(output_with_heads, "b h t d -> b t (h d)")
315
+
316
+ return self.out_proj(output_without_heads)
317
+
318
+ def attention_logits(self, q, k, v):
319
+
320
+ q, k, v = self.project_qkv(q, k, v)
202
321
 
203
322
  # Divide Q/K/V into heads
204
323
  q = rearrange(q, "b t (h d) -> b h t d", h=self.n_heads)
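With flash-attn installed (and `talking_heads` off), `forward` keeps tensors in (batch, tokens, heads, head_dim) layout for `flash_attn_func`; otherwise it moves heads to a leading dimension and computes softmax(QKᵀ · scale)V explicitly. A small sketch of that fallback arithmetic with assumed toy shapes; it does not require flash-attn:

```python
import torch
import torch.nn.functional as F
from einops import rearrange

b, t, h, d = 2, 5, 4, 8                 # assumed batch, tokens, heads, head_dim
q = torch.randn(b, t, h * d)
k = torch.randn(b, t, h * d)
v = torch.randn(b, t, h * d)

# Fallback layout: heads become a leading dimension
q = rearrange(q, "b t (h d) -> b h t d", h=h)
k = rearrange(k, "b t (h d) -> b h t d", h=h)
v = rearrange(v, "b t (h d) -> b h t d", h=h)

scores = (q @ k.transpose(-1, -2)) * (8 / d)   # muP-style "d" scaling from this module
attn = F.softmax(scores, dim=-1)
out = rearrange(attn @ v, "b h t d -> b t (h d)")
print(out.shape)  # torch.Size([2, 5, 32])
```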
@@ -207,19 +326,24 @@ class MHAttention(nn.Module):
207
326
 
208
327
  qk_scores = q @ k.transpose(-1, -2)
209
328
 
210
- qk_scores /= math.sqrt(self.head_dim)
329
+ qk_scores *= self.scaling_factor
211
330
 
212
331
  # Apply mask if causal (must come before softmax)
213
332
  if self.causal:
214
333
  qk_scores.masked_fill_(self.mask, float("-inf"))
215
334
 
216
- qk_scores = F.softmax(qk_scores, dim=-1)
217
-
218
- output_with_heads = qk_scores @ v
335
+ return qk_scores # (batch, head, seq_len, seq_len)
219
336
 
220
- output_without_heads = rearrange(output_with_heads, "b h t d -> b t (h d)")
221
-
222
- return self.out_proj(output_without_heads)
337
+ def reset_parameters(self):
338
+ # Default nn.Linear init is kaiming_uniform, which is fine
339
+ self.q_proj.reset_parameters()
340
+ self.k_proj.reset_parameters()
341
+ self.v_proj.reset_parameters()
342
+ self.out_proj.reset_parameters()
343
+ if self.talking_heads:
344
+ # Initialize close to identity
345
+ nn.init.eye_(self.head_projection.weight)
346
+ nn.init.eye_(self.sample_projection.weight)
223
347
 
224
348
 
225
349
  class FeedforwardBlock(nn.Module):
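`talking_heads` mixes information across heads with learned n_heads × n_heads projections, applied to the logits before the softmax and to the attention weights after it (per Shazeer et al., "Talking-Heads Attention"). Because `reset_parameters` uses `nn.init.eye_`, the projections start as the identity, so the block initially behaves like vanilla attention. A sketch of the pre-softmax mixing step with assumed toy shapes:

```python
import torch
import torch.nn as nn

n_heads, tokens = 4, 6
scores = torch.randn(2, n_heads, tokens, tokens)      # (batch, head, query, key) logits

head_projection = nn.Linear(n_heads, n_heads, bias=False)
nn.init.eye_(head_projection.weight)                  # identity init: starts as a no-op

# Mix the head dimension: output head o is a learned combination of input heads h
mixed = torch.einsum("b h i j, o h -> b o i j", scores, head_projection.weight)

print(torch.allclose(mixed, scores))  # True while the projection is still the identity
```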
@@ -235,44 +359,123 @@ class FeedforwardBlock(nn.Module):
235
359
  activation=nn.ReLU,
236
360
  activation_kwargs=None,
237
361
  dropout=0.0,
238
- linear_module=nn.Linear,
362
+ inner_dropout=None,
363
+ outer_dropout=None,
364
+ linear_module_up=nn.Linear,
365
+ linear_module_down=nn.Linear,
239
366
  pre_norm=True,
240
367
  normformer=False,
241
- raw_input=False,
368
+ post_norm=True,
369
+ residual_path=True,
370
+ checkpoint=True,
242
371
  ):
243
372
  super().__init__()
244
373
 
374
+ self.checkpoint = checkpoint
375
+ self.residual_path = residual_path
376
+ self.post_norm = post_norm
377
+ self.xglu = activation.__name__.endswith("GLU")
378
+
379
+ if self.residual_path and (output_features < input_features):
380
+ raise ValueError(
381
+ "If the number of output features will be less than "
382
+ "the number of input features, then `residual_path` "
383
+ "should be set to False."
384
+ )
385
+
386
+ if self.post_norm:
387
+ self.layernorm = nn.LayerNorm(output_features)
388
+
245
389
  if activation_kwargs is not None:
246
390
  self.activation = activation(**activation_kwargs)
247
391
  else:
248
392
  self.activation = activation()
249
393
 
250
- if raw_input:
251
- self.memory_type = SpectralNormLinear
252
- else:
253
- self.memory_type = nn.Linear
254
-
255
- self.dropout = nn.Dropout(dropout)
394
+ self.inner_dropout = nn.Dropout(
395
+ inner_dropout if inner_dropout is not None else dropout
396
+ )
397
+ self.outer_dropout = nn.Dropout(
398
+ outer_dropout if outer_dropout is not None else dropout
399
+ )
256
400
 
257
401
  self.max_features = (
258
- 2 * ratio * output_features
259
- if activation.__name__.endswith("GLU")
260
- else ratio * output_features
402
+ 2 * int(ratio * output_features)
403
+ if self.xglu
404
+ else int(ratio * output_features)
405
+ )
406
+
407
+ self.linear_in = linear_module_up(input_features, self.max_features)
408
+ self.linear_out = linear_module_down(
409
+ int(ratio * output_features), output_features
261
410
  )
262
411
 
263
412
  self.process = nn.Sequential(
264
413
  *[
265
414
  nn.LayerNorm(input_features) if pre_norm else nn.Identity(),
266
- linear_module(input_features, self.max_features),
415
+ self.linear_in,
267
416
  self.activation,
268
- nn.LayerNorm(ratio * output_features) if normformer else nn.Identity(),
269
- self.memory_type(ratio * output_features, output_features),
270
- self.dropout,
417
+ self.inner_dropout,
418
+ (
419
+ nn.LayerNorm(int(ratio * output_features))
420
+ if normformer
421
+ else nn.Identity()
422
+ ),
423
+ self.linear_out,
424
+ self.outer_dropout,
271
425
  ]
272
426
  )
273
427
 
428
+ self.recycling_enabled = False
429
+ if hasattr(self.linear_in, "row_recycling_rate") and hasattr(
430
+ self.linear_out, "column_recycling_rate"
431
+ ):
432
+ self.recycling_enabled = True
433
+ self.master_recycling_rate = self.linear_in.row_recycling_rate
434
+ self.linear_in.row_recycling_rate = 0.0
435
+ self.linear_out.column_recycling_rate = 0.0
436
+ if (
437
+ hasattr(self.linear_in, "column_recycling_rate")
438
+ and self.linear_in.column_recycling_rate > 0
439
+ ) or (
440
+ hasattr(self.linear_out, "row_recycling_rate")
441
+ and self.linear_out.row_recycling_rate > 0
442
+ ):
443
+ raise NotImplementedError(
444
+ "At the moment this layer can only support recycling linear "
445
+ "layers if the in layer resets only rows and the out layer "
446
+ "resets only columns."
447
+ )
448
+
449
+ self.reset_parameters()
450
+
274
451
  def forward(self, x):
275
- return self.process(x)
452
+
453
+ # Recycle weights if using recycling linear layers
454
+ if self.training and self.recycling_enabled:
455
+ indices = self.linear_out.get_reset_indices(1)
456
+ self.linear_in.reset_rows(indices, incoming_data=x)
457
+ self.linear_out.reset_columns(indices)
458
+
459
+ if self.checkpoint:
460
+ processed = checkpoint(self.process, x, use_reentrant=False)
461
+ else:
462
+ processed = self.process(x)
463
+
464
+ if self.residual_path and self.post_norm:
465
+ return self.layernorm(x + processed)
466
+ elif self.residual_path:
467
+ return x + processed
468
+ else:
469
+ return processed
470
+
471
+ def reset_parameters(self):
472
+ if self.post_norm:
473
+ self.layernorm.reset_parameters()
474
+
475
+ # Iterate over the sequential block to reset parameters
476
+ for module in self.process:
477
+ if hasattr(module, "reset_parameters"):
478
+ module.reset_parameters()
276
479
 
277
480
 
278
481
  class TransformerBlock(nn.Module):
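The rebuilt FeedforwardBlock doubles the up-projection width whenever the activation class name ends in "GLU" (gated activations split their input into value and gate halves) and can wrap the whole Sequential in gradient checkpointing to save activation memory. A sketch of both ideas; `GeGLU` here is a hypothetical stand-in activation, not an import from this package:

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class GeGLU(nn.Module):
    """Hypothetical gated activation: half the features gate the other half."""
    def forward(self, x):
        a, b = x.chunk(2, dim=-1)
        return a * torch.nn.functional.gelu(b)

d_model, ratio = 64, 4
hidden = int(ratio * d_model)                     # 256 features after the activation
is_glu = GeGLU.__name__.endswith("GLU")           # True -> up-projection width is doubled

process = nn.Sequential(
    nn.LayerNorm(d_model),
    nn.Linear(d_model, 2 * hidden if is_glu else hidden),
    GeGLU() if is_glu else nn.ReLU(),
    nn.Linear(hidden, d_model),
)

x = torch.randn(2, 10, d_model, requires_grad=True)
y = checkpoint(process, x, use_reentrant=False)   # recompute activations during backward
print(y.shape)  # torch.Size([2, 10, 64])
```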
@@ -289,30 +492,57 @@ class TransformerBlock(nn.Module):
289
492
  seq_len,
290
493
  d_model,
291
494
  n_heads,
292
- position_embedding_type="absolute", # absolute or relative
495
+ relative_position_embedding=False,
293
496
  source_size=None,
294
- bos_tokens=0,
497
+ utility_tokens=0,
498
+ talking_heads=False,
295
499
  mlp_ratio=4,
296
500
  activation: nn.Module = nn.ReLU,
297
501
  activation_kwargs: Optional[dict] = None,
298
- mlp_dropout=0.0,
502
+ ff_linear_module_up=None,
503
+ ff_linear_module_down=None,
504
+ msa_scaling="d",
505
+ ff_dropout=0.0,
506
+ ff_inner_dropout=0.0,
507
+ ff_outer_dropout=0.0,
299
508
  msa_dropout=0.0,
300
509
  identity_probability=0.0,
301
510
  causal=False,
302
511
  linear_module=nn.Linear,
303
512
  pre_norm=True,
513
+ post_norm=False,
304
514
  normformer=False,
515
+ checkpoint_ff=True,
516
+ layerscale=True,
305
517
  ):
518
+ """
519
+ Args:
520
+ msa_scaling: how should the attention logits be scaled? Can be "sqrtd"
521
+ to mimic the original Attention is All You Need approach of
522
+ dividing by the sqrt of the embedding Dimension or "d" per
523
+ "Tensor Programs V...". Default "d"
524
+ """
525
+
306
526
  super().__init__()
307
527
 
308
528
  self.pre_norm = pre_norm
529
+ self.post_norm = post_norm
530
+ self.normformer = normformer
309
531
 
310
- self.identity_probability = identity_probability
532
+ self.drop_path = DropPath(drop_prob=identity_probability, scale_by_keep=True)
311
533
 
312
534
  self.layer_norm_1 = nn.LayerNorm(d_model)
313
535
  self.layer_norm_2 = nn.LayerNorm(d_model)
536
+ self.layer_norm_3 = nn.LayerNorm(d_model)
314
537
 
315
- if position_embedding_type == "relative":
538
+ if layerscale:
539
+ self.layerscale1 = LayerScale(d_model)
540
+ self.layerscale2 = LayerScale(d_model)
541
+ else:
542
+ self.layerscale1 = nn.Identity()
543
+ self.layerscale2 = nn.Identity()
544
+
545
+ if relative_position_embedding:
316
546
  max_freq = int(max(source_size) / 2) # Suggested by Gemini!
317
547
  if d_model < 16:
318
548
  dim = d_model
@@ -333,63 +563,87 @@ class TransformerBlock(nn.Module):
333
563
  linear_module=linear_module,
334
564
  rotary_embedding=self.rotary_embedding,
335
565
  source_size=source_size,
336
- bos_tokens=bos_tokens,
566
+ utility_tokens=utility_tokens,
567
+ talking_heads=talking_heads,
568
+ scaling=msa_scaling,
337
569
  )
338
570
 
339
- # Submodules for the feedforward process
571
+ # Submodule for the feedforward process
340
572
  self.ff = FeedforwardBlock(
341
573
  d_model,
342
574
  mlp_ratio,
343
575
  d_model,
344
576
  activation=activation,
345
577
  activation_kwargs=activation_kwargs,
346
- dropout=mlp_dropout,
347
- linear_module=linear_module,
348
- pre_norm=pre_norm,
578
+ dropout=ff_dropout,
579
+ inner_dropout=ff_inner_dropout,
580
+ outer_dropout=ff_outer_dropout,
581
+ linear_module_up=(
582
+ ff_linear_module_up
583
+ if ff_linear_module_up is not None
584
+ else linear_module
585
+ ),
586
+ linear_module_down=(
587
+ ff_linear_module_down
588
+ if ff_linear_module_down is not None
589
+ else linear_module
590
+ ),
591
+ pre_norm=False, # Handled outside the block
349
592
  normformer=normformer,
593
+ post_norm=False, # Handled outside the block
594
+ residual_path=False, # Handled outside the block
595
+ checkpoint=checkpoint_ff,
350
596
  )
351
597
 
598
+ self.reset_parameters()
599
+
352
600
  @property
353
601
  def _kv_distance(self) -> float:
354
602
  return self.attn._kv_distance
355
603
 
356
604
  def forward(self, x):
357
- if not self.training:
358
- identity_probability = 0.0
359
- else:
360
- identity_probability = self.identity_probability
361
-
362
- # perform the identity operation for some rows in the batch
363
- identity_count = random.binomial(n=x.size(0), p=identity_probability)
364
- shuffle_indices = torch.randperm(x.size(0), device=x.device)
365
- unshuffle_indices = torch.argsort(shuffle_indices)
366
- shuffled = x[shuffle_indices, :, :]
367
- identity_x = shuffled[:identity_count, :, :]
368
- process_x = shuffled[identity_count:, :, :]
369
605
 
370
606
  if self.pre_norm:
371
- norm_process_x = self.layer_norm_1(process_x)
372
- process_x = process_x + self.attn(
373
- norm_process_x, norm_process_x, norm_process_x
374
- )
375
- process_x = process_x + self.ff(process_x)
376
- else: # post-norm
377
- process_x = process_x + self.attn(process_x, process_x, process_x)
378
- norm_process_x = self.layer_norm_1(process_x)
379
- process_x = process_x + self.ff(process_x)
380
-
381
- # Always post norm as eventually we reach the classification head!
382
- x = self.layer_norm_2(
383
- torch.cat([identity_x, process_x])[unshuffle_indices, :, :].contiguous()
384
- )
607
+ x = self.layer_norm_1(x)
608
+ x = x + self.drop_path(self.layerscale1(self.attn(x, x, x)))
609
+ x = self.layer_norm_2(x)
610
+ x = x + self.drop_path(self.layerscale2(self.ff(x)))
611
+ if self.post_norm: # i.e. in addition! Pre and post.
612
+ x = self.layer_norm_3(x)
613
+ elif self.post_norm: # i.e. only, not prenorm, just post
614
+ x = x + self.drop_path(self.layerscale1(self.attn(x, x, x)))
615
+ x = self.layer_norm_1(x)
616
+ x = x + self.drop_path(self.layerscale2(self.ff(x)))
617
+ x = self.layer_norm_2(x)
618
+ else: # Not pre or post norm. Stand well back.
619
+ x = x + self.drop_path(self.layerscale1(self.attn(x, x, x)))
620
+ x = x + self.drop_path(self.layerscale2(self.ff(x)))
385
621
 
386
622
  return x
387
623
 
624
+ def attention_logits(self, x):
625
+ """
626
+ Return the pre-softmax attention logits produced by this block's attention module.
627
+ """
628
+ if self.pre_norm:
629
+ x = self.layer_norm_1(x)
630
+ return self.attn.attention_logits(x, x, x)
631
+ else:
632
+ return self.attn.attention_logits(x, x, x)
633
+
634
+ def reset_parameters(self):
635
+ self.layer_norm_1.reset_parameters()
636
+ self.layer_norm_2.reset_parameters()
637
+ self.layer_norm_3.reset_parameters()
638
+
639
+ self.attn.reset_parameters()
640
+ self.ff.reset_parameters()
641
+
388
642
 
389
643
  class TransformerEncoder(nn.Module):
390
644
  """
391
645
  This assumes we already get a sequence of embeddings (e.g. word or image
392
- patch embeddings). It uses learned positional embeddings.
646
+ patch embeddings).
393
647
  """
394
648
 
395
649
  def __init__(
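The new TransformerBlock.forward supports three normalisation layouts: pre-norm (optionally with an extra post-norm applied after both residual branches), post-norm only, and no normalisation at all. A condensed sketch of that control flow, with `attn_branch` and `ff_branch` as stand-ins for the DropPath/LayerScale-wrapped attention and feedforward branches:

```python
import torch
import torch.nn as nn

d_model = 32
norm1, norm2, norm3 = nn.LayerNorm(d_model), nn.LayerNorm(d_model), nn.LayerNorm(d_model)
attn_branch = nn.Linear(d_model, d_model)   # stand-in for drop_path(layerscale1(attn(x, x, x)))
ff_branch = nn.Linear(d_model, d_model)     # stand-in for drop_path(layerscale2(ff(x)))

def block(x, pre_norm=True, post_norm=False):
    if pre_norm:
        x = norm1(x)
        x = x + attn_branch(x)
        x = norm2(x)
        x = x + ff_branch(x)
        if post_norm:                        # pre- and post-norm together
            x = norm3(x)
    elif post_norm:                          # post-norm only
        x = x + attn_branch(x)
        x = norm1(x)
        x = x + ff_branch(x)
        x = norm2(x)
    else:                                    # no normalisation at all
        x = x + attn_branch(x)
        x = x + ff_branch(x)
    return x

print(block(torch.randn(2, 7, d_model)).shape)  # torch.Size([2, 7, 32])
```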
@@ -398,53 +652,93 @@ class TransformerEncoder(nn.Module):
398
652
  d_model,
399
653
  n_layers,
400
654
  n_heads,
401
- position_embedding_type="absolute", # absolute or relative
655
+ absolute_position_embedding=True,
656
+ relative_position_embedding=False,
402
657
  source_size=None,
403
658
  mlp_ratio=4,
404
659
  activation: nn.Module = nn.ReLU,
405
660
  activation_kwargs: Optional[dict] = None,
406
- mlp_dropout=0.0,
661
+ ff_linear_module_up=None,
662
+ ff_linear_module_down=None,
663
+ ff_dropout=0.0,
664
+ ff_inner_dropout=0.0,
665
+ ff_outer_dropout=0.0,
407
666
  msa_dropout=0.0,
408
667
  stochastic_depth=0.0,
409
668
  causal=False,
410
669
  linear_module=nn.Linear,
411
- bos_tokens=0,
412
- return_bos_tokens=False,
670
+ utility_tokens=0,
671
+ talking_heads=False,
672
+ return_utility_tokens=False,
413
673
  pre_norm=True,
674
+ post_norm=False,
414
675
  normformer=False,
676
+ msa_scaling="d",
677
+ checkpoint_ff=True,
678
+ layerscale=True,
415
679
  ):
416
- if position_embedding_type == "relative":
417
- assert source_size is not None # TODO: make this a proper exception
680
+ """
681
+ Args:
682
+ msa_scaling: how should the attention logits be scaled? Can be "sqrtd"
683
+ to mimic the original "Attention Is All You Need" approach of
684
+ dividing by the square root of the head dimension, or "d" per
685
+ "Tensor Programs V...". Default: "d".
686
+ """
687
+
688
+ if relative_position_embedding and (source_size is None):
689
+ raise ValueError(
690
+ "`source_size` for TransformerEncoder cannot be None if"
691
+ " `relative_position_embedding` is True"
692
+ )
693
+
694
+ if absolute_position_embedding and (seq_len is None):
695
+ raise ValueError(
696
+ "`seq_len` for TransformerEncoder cannot be None if"
697
+ " `absolute_position_embedding` is True"
698
+ )
418
699
 
419
700
  super().__init__()
701
+
702
+ if FLASH_ATTN and talking_heads:
703
+ warnings.warn(
704
+ "Using talking heads currently prevents using flash attention.",
705
+ stacklevel=2,
706
+ )
707
+
420
708
  self.seq_len = seq_len
421
709
  self.n_heads = n_heads
422
- self._bos_tokens = bos_tokens
423
- self.return_bos_tokens = return_bos_tokens
424
-
425
- # Initialise BOS tokens with normal init, like usual Pytorch embeddings
426
- if self._bos_tokens:
427
- self._bos_embedding = nn.Parameter(torch.empty(self._bos_tokens, d_model))
428
- nn.init.normal_(self._bos_embedding, mean=0.0, std=1.0)
429
- self.full_sequence_length = self.seq_len + self._bos_tokens
710
+ self._utility_tokens = utility_tokens
711
+ self.return_utility_tokens = return_utility_tokens
712
+
713
+ # Initialise utility tokens with normal init, like usual Pytorch embeddings
714
+ if self._utility_tokens:
715
+ self._utility_token_embedding = nn.Parameter(
716
+ torch.empty(self._utility_tokens, d_model)
717
+ )
718
+ nn.init.normal_(self._utility_token_embedding, mean=0.0, std=1.0)
719
+ else:
720
+ self._utility_token_embedding = None
721
+
722
+ if self._utility_tokens and (self.seq_len is not None):
723
+ self.full_sequence_length = self.seq_len + self._utility_tokens
430
724
  else:
431
- self._bos_embedding = None
432
725
  self.full_sequence_length = self.seq_len
433
726
 
434
727
  self.d_model = d_model
435
728
 
436
- self.position_embedding_type = position_embedding_type
437
-
438
- if self.position_embedding_type == "absolute":
729
+ if absolute_position_embedding:
439
730
  self.absolute_position_embedding = nn.Embedding(
440
731
  self.full_sequence_length, d_model
441
732
  )
733
+ else:
734
+ self.absolute_position_embedding = None
442
735
 
443
- self.mlp_dropout = mlp_dropout
736
+ self.mlp_dropout = ff_dropout
444
737
  self.msa_dropout = msa_dropout
445
738
  self.stochastic_depth = stochastic_depth
446
739
 
447
- assert isinstance(n_layers, int) # XXX: make this a proper Exception
740
+ assert isinstance(n_layers, int)
741
+
448
742
  if n_layers == 1:
449
743
  self.stochastic_depth_probabilities = [0.0]
450
744
  else:
@@ -459,35 +753,48 @@ class TransformerEncoder(nn.Module):
459
753
  self.full_sequence_length,
460
754
  d_model,
461
755
  n_heads,
462
- position_embedding_type=position_embedding_type,
756
+ relative_position_embedding=relative_position_embedding,
463
757
  source_size=source_size,
464
- bos_tokens=bos_tokens,
758
+ utility_tokens=utility_tokens,
759
+ talking_heads=talking_heads,
465
760
  mlp_ratio=mlp_ratio,
466
761
  activation=activation,
467
762
  activation_kwargs=activation_kwargs,
468
- mlp_dropout=mlp_dropout,
763
+ ff_linear_module_up=ff_linear_module_up,
764
+ ff_linear_module_down=ff_linear_module_down,
765
+ msa_scaling=msa_scaling,
766
+ ff_dropout=ff_dropout,
767
+ ff_inner_dropout=ff_inner_dropout,
768
+ ff_outer_dropout=ff_outer_dropout,
469
769
  msa_dropout=msa_dropout,
470
770
  identity_probability=self.stochastic_depth_probabilities[i],
471
771
  causal=causal,
472
772
  linear_module=linear_module,
473
773
  pre_norm=pre_norm,
774
+ post_norm=post_norm,
474
775
  normformer=normformer,
776
+ checkpoint_ff=checkpoint_ff,
777
+ layerscale=layerscale,
475
778
  )
476
779
  for i in range(n_layers)
477
780
  ]
478
781
  )
479
782
 
783
+ self.reset_parameters()
784
+
480
785
  @property
481
786
  def _kv_distances(self) -> float:
482
787
  return ",".join([str(block._kv_distance) for block in self.blocks])
483
788
 
484
- def forward(self, x):
485
- if self._bos_tokens:
486
- x = torch.cat([self._bos_embedding.expand(x.size(0), -1, -1), x], dim=1)
789
+ def preprocess(self, x):
790
+ if self._utility_tokens:
791
+ x = torch.cat(
792
+ [self._utility_token_embedding.expand(x.size(0), -1, -1), x], dim=1
793
+ )
487
794
  else:
488
795
  x = x
489
796
 
490
- if self.position_embedding_type == "absolute":
797
+ if self.absolute_position_embedding is not None:
491
798
  x = x + self.absolute_position_embedding(
492
799
  torch.arange(
493
800
  0, self.full_sequence_length, dtype=torch.long, device=x.device
@@ -496,10 +803,40 @@ class TransformerEncoder(nn.Module):
496
803
  ) # to shape (1, seq_len) to broadcast over batch
497
804
  )
498
805
 
806
+ return x
807
+
808
+ def forward(self, x):
809
+
810
+ x = self.preprocess(x)
811
+
499
812
  for block in self.blocks:
500
813
  x = block(x)
501
814
 
502
- if self._bos_tokens and not self.return_bos_tokens:
503
- return x[:, self._bos_tokens :, :]
815
+ if self._utility_tokens and not self.return_utility_tokens:
816
+ return x[:, self._utility_tokens :, :]
504
817
  else:
505
818
  return x
819
+
820
+ def attention_logits(self, x):
821
+
822
+ x = self.preprocess(x)
823
+
824
+ layer_scores = []
825
+
826
+ for block in self.blocks:
827
+ # Get attention scores with shape (batch, 1, head, seq_len, seq_len)
828
+ layer_attention_logits = block.attention_logits(x).unsqueeze(1)
829
+ layer_scores.append(layer_attention_logits)
830
+ x = block(x)
831
+
832
+ return torch.cat(layer_scores, dim=1) # (batch, layer, head, seq_len, seq_len)
833
+
834
+ def reset_parameters(self):
835
+ if self._utility_token_embedding is not None:
836
+ nn.init.normal_(self._utility_token_embedding, mean=0.0, std=1.0)
837
+
838
+ if self.absolute_position_embedding is not None:
839
+ self.absolute_position_embedding.reset_parameters()
840
+
841
+ for block in self.blocks:
842
+ block.reset_parameters()
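Putting the pieces together, a hedged usage sketch of the new encoder API, assuming the constructor parameters shown in this diff, made-up hyperparameter values, and the pure-PyTorch attention path (i.e. flash-attn not installed): learned utility tokens are prepended internally and stripped from the output unless `return_utility_tokens=True`, and `attention_logits` stacks per-layer, per-head pre-softmax scores for inspection.

```python
import torch
from broccoli.transformer import TransformerEncoder

# Assumed hyperparameters; argument names follow the signature shown in this diff.
encoder = TransformerEncoder(
    seq_len=16,
    d_model=64,
    n_layers=2,
    n_heads=4,
    utility_tokens=2,
    absolute_position_embedding=True,
)

x = torch.randn(8, 16, 64)            # (batch, tokens, d_model)
out = encoder(x)                      # utility tokens stripped by default
print(out.shape)                      # expected: torch.Size([8, 16, 64])

logits = encoder.attention_logits(x)  # (batch, layer, head, seq, seq), seq = 16 + 2
print(logits.shape)                   # expected: torch.Size([8, 2, 4, 18, 18])
```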