broccoli-ml 6.0.0__py3-none-any.whl → 13.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
broccoli/activation.py CHANGED
@@ -46,10 +46,7 @@ class GELU(nn.Module):
 
 class Swish(nn.Module):
     """
-    Implementation of (beta) SwiGLU, as introduced in "GLU Variants Improve Transformer"
-    (https://arxiv.org/abs/2002.05202v1) and used to great effect in LLaMa 2.0.
-
-    Halves the incoming parameter count, which should be scaled up before input.
+    Implementation of (beta) Swish
     """
 
     def __init__(self) -> None:
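Note: the corrected docstring now names Swish rather than SwiGLU. For reference only (not broccoli's exact implementation), Swish is x · sigmoid(βx), which equals SiLU when β = 1:

```python
import torch

def swish(x: torch.Tensor, beta: float = 1.0) -> torch.Tensor:
    # Reference Swish: x * sigmoid(beta * x); matches torch.nn.functional.silu when beta == 1.
    # broccoli's Swish module may treat beta as a learnable parameter instead.
    return x * torch.sigmoid(beta * x)
```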
broccoli/linear.py CHANGED
@@ -1,4 +1,8 @@
 import math
+import random
+import warnings
+from typing import Union, List, Iterable
+
 import torch
 from torch import nn
 from torch.nn import functional as F
@@ -136,3 +140,229 @@ class WeightNormedLinear(nn.Module):
             f"WeightNormedLinear(in_features={self.in_features},"
             f"out_features={self.out_features}, bias={self.use_bias})"
         )
+
+
+class RecyclingLinear(nn.Module):
+    def __init__(
+        self,
+        in_features: int,
+        out_features: int,
+        bias: bool = True,
+        row_recycling_rate: float = 0.0,
+        column_recycling_rate: float = 0.0,
+        adaptive=False,
+        xglu=False,
+    ):
+        super().__init__()
+        self.in_features = in_features
+        self.out_features = out_features
+        self.bias = bias
+        self.xglu = xglu
+        self.linear = nn.Linear(in_features, out_features, bias=bias)
+        self.row_recycling_rate = row_recycling_rate
+        self.column_recycling_rate = column_recycling_rate
+        self.adaptive = adaptive
+        self.optimisers = []
+        self.initial_learning_rates = []
+        self._warned_about_registration = False
+
+    def register_optimiser(self, optimiser: torch.optim.Optimizer):
+        self.optimisers.append(optimiser)
+        self.initial_learning_rates.append(self._get_learning_rate(optimiser))
+        if self.initial_learning_rates[-1] == 0.0:
+            warnings.warn(
+                "Learning rate of registered optimiser was 0.0 - make sure "
+                "you haven't initialised a scheduler before registering the "
+                "optimiser",
+                stacklevel=2,
+            )
+
+    def _get_learning_rate(self, optimiser: torch.optim.Optimizer):
+        for group in optimiser.param_groups:
+            for param in group["params"]:
+                if param is self.linear.weight:
+                    return group["lr"]
+
+    def _get_multiplier(self):
+        if not self.adaptive or not self.optimisers:
+            return 1.0
+        else:
+            init = self.initial_learning_rates
+            current = [self._get_learning_rate(o) for o in self.optimisers]
+            pairs = zip(current, init, strict=True)
+            multipliers = [a / b for a, b in pairs if b != 0.0]
+            return min(multipliers) if multipliers else 0.0
+
+    def reset_rows(self, indices, incoming_data=None):
+        """
+        Resets rows.
+        If incoming_data is provided, resets to the centroid (mean) of that data.
+        If not, resets to the mean of existing weights.
+        """
+        if not torch.is_tensor(indices):
+            idx_tensor = torch.as_tensor(
+                list(indices), dtype=torch.long, device=self.linear.weight.device
+            )
+        else:
+            idx_tensor = indices
+
+        if idx_tensor.numel() == 0:
+            return
+
+        if incoming_data is not None:
+            target_center = self._mean_input_weights(incoming_data)
+        else:
+            target_center = self._mean_value_weights()
+
+        target_center = target_center.expand(idx_tensor.size(0), -1)
+
+        if self.xglu:
+            gate_indices = idx_tensor
+            value_indices = idx_tensor + (self.linear.out_features // 2)
+            self._update_weights(gate_indices, 0, target_center, self.optimisers)
+            self._update_weights(value_indices, 0, target_center, self.optimisers)
+        else:
+            self._update_weights(idx_tensor, 0, target_center, self.optimisers)
+
+    def reset_columns(self, indices):
+        if not torch.is_tensor(indices):
+            idx_tensor = torch.as_tensor(
+                list(indices), dtype=torch.long, device=self.linear.weight.device
+            )
+        else:
+            idx_tensor = indices
+
+        if idx_tensor.size(0):
+            random_weights = self._random_weights(
+                self.linear.weight.size(0), indices.size(0)
+            )
+            # Make random col weights quiet so they don't introduce loud noise...
+            # ...but not so quiet that FP16 zeros them and ruins symmetry breaking!
+            random_weights *= 0.1
+            self._update_weights(indices, 1, random_weights, self.optimisers)  # dim
+        else:
+            return
+
+    def forward(self, x):
+        if self.training and self.optimisers:
+            self.reset_rows(self.get_reset_indices(0))
+            self.reset_columns(self.get_reset_indices(1))
+        elif self.training and not self._warned_about_registration:
+            warnings.warn(
+                "RecyclingLinear: No optimiser registered. Recycling disabled.",
+                stacklevel=2,
+            )
+            self._warned_about_registration = True
+
+        return self.linear(x)
+
+    def get_reset_indices(self, dim):
+        base_rate = self.row_recycling_rate if dim == 0 else self.column_recycling_rate
+        p = base_rate * self._get_multiplier()
+        if dim == 0:
+            if self.xglu:
+                sample_space = self.linear.out_features // 2
+            else:
+                sample_space = self.linear.out_features
+        elif dim == 1:
+            sample_space = self.linear.in_features
+        else:
+            raise ValueError("`dim` must be 0 or 1")
+
+        # Sample the indices
+        probs = torch.rand(sample_space, device=self.linear.weight.device)
+        mask = probs < p
+        if mask.any():
+            return torch.nonzero(mask).squeeze(-1)
+        else:
+            return torch.tensor([], dtype=torch.long, device=self.linear.weight.device)
+
+    def _random_weights(self, rows, columns):
+        device = self.linear.weight.device
+        weights = self.linear.weight.data
+        stdv = 1.0 / math.sqrt(weights.size(1))
+        random_weights = torch.rand(rows, columns, device=device)
+        random_weights -= 0.5  # Range [-0.5, +0.5]
+        random_weights *= 2.0 * stdv  # Range [-stdv, +stdv]
+        return random_weights
+
+    def _mean_input_weights(self, input):
+        reduce_dims = list(range(input.ndim - 1))
+        data_mean = input.detach().mean(dim=reduce_dims, keepdim=True)
+
+        weights = self.linear.weight.data
+        stdv = 1.0 / math.sqrt(weights.size(1))
+        data_norm = data_mean.std() + 1e-6
+        scale_factor = stdv / data_norm
+
+        return data_mean * scale_factor
+
+    def _mean_value_weights(self):
+        """
+        Only used when self.xglu
+        """
+        weights = self.linear.weight.data
+        rows = weights.size(0)
+        if self.xglu:
+            return self.linear.weight[int(rows / 2) :].data.mean(dim=0, keepdim=True)
+        else:
+            return self.linear.weight.data.mean(dim=0, keepdim=True)
+
+    def _mean_gate_weights(self):
+        """
+        Only used when self.xglu
+        """
+        weights = self.linear.weight.data
+        rows = weights.size(0)
+        return self.linear.weight[: int(rows / 2)].data.mean(dim=0, keepdim=True)
+
+    def _update_weights(
+        self,
+        indices: Iterable[int],
+        dim: int,
+        data: torch.Tensor,
+        optimisers: Union[
+            List[torch.optim.Optimizer], torch.optim.Optimizer, None
+        ] = None,
+    ):
+        if optimisers is None:
+            optimisers = []
+        if not isinstance(optimisers, list):
+            optimisers = [optimisers]
+
+        if not torch.is_tensor(indices):
+            idx_tensor = torch.as_tensor(
+                list(indices), dtype=torch.long, device=self.linear.weight.device
+            )
+        else:
+            idx_tensor = indices
+
+        if idx_tensor.numel() == 0:
+            return
+
+        with torch.no_grad():
+            if dim == 0:
+                self.linear.weight.data[idx_tensor] = data
+            elif dim == 1:
+                self.linear.weight.data[:, idx_tensor] = data
+            else:
+                raise ValueError("`dim` must be 0 or 1")
+
+        self._reset_optim_state(self.linear.weight, idx_tensor, optimisers, dim=dim)
+
+    def _reset_optim_state(self, param, idx_tensor, optimisers, dim):
+        """
+        Zeroes out the optimizer state for the given indices in a single operation.
+        """
+        for optimiser in optimisers:
+            if param not in optimiser.state:
+                continue
+            state = optimiser.state[param]
+
+            for _, buffer in state.items():
+                if torch.is_tensor(buffer) and buffer.shape == param.shape:
+                    # Vectorized zeroing
+                    if dim == 0:
+                        buffer[idx_tensor] = 0.0
+                    else:
+                        buffer[:, idx_tensor] = 0.0
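Note: a minimal usage sketch for the new RecyclingLinear, based only on the methods shown above; the training setup itself is illustrative:

```python
import torch
from broccoli.linear import RecyclingLinear

layer = RecyclingLinear(32, 64, row_recycling_rate=0.01, adaptive=True)
optimiser = torch.optim.AdamW(layer.parameters(), lr=1e-3)

# Register the optimiser before attaching any LR scheduler; register_optimiser()
# records the current learning rate and warns if it is already 0.0.
layer.register_optimiser(optimiser)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimiser, T_max=100)

for _ in range(100):
    x = torch.randn(8, 32)
    loss = layer(x).pow(2).mean()  # forward() also recycles rows/columns in training mode
    optimiser.zero_grad()
    loss.backward()
    optimiser.step()
    scheduler.step()  # with adaptive=True, recycling slows as the learning rate decays
```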
broccoli/transformer.py CHANGED
@@ -1,3 +1,4 @@
+import warnings
 import math
 from typing import Optional, Tuple
 
@@ -13,12 +14,19 @@ from .rope import RotaryEmbedding, apply_rotary_emb
 try:
     from flash_attn import flash_attn_func
 
+    print("Using flash-attn.")
     FLASH_ATTN = True
 except ImportError:
     pass
     FLASH_ATTN = False
 
 
+def scale_parameters(torch_module: nn.Module, factor: float):
+    with torch.no_grad():
+        for param in torch_module.parameters():
+            param.mul_(factor)
+
+
 def drop_path(
     x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True
 ):
@@ -77,9 +85,11 @@ class MHAttention(nn.Module):
         seq_len=None,
         linear_module: nn.Module = nn.Linear,
         utility_tokens=0,
+        talking_heads=False,
         rotary_embedding=None,
         source_size=None,
         scaling="d",
+        beta=1.0,
     ):
         """
         Args:
@@ -95,10 +105,20 @@
         if causal:
             assert seq_len is not None
 
+        self.talking_heads = talking_heads
+
+        if self.talking_heads:
+            self.head_projection = nn.Linear(n_heads, n_heads, bias=False)
+            self.sample_projection = nn.Linear(n_heads, n_heads, bias=False)
+        else:
+            self.head_projection = None
+            self.sample_projection = None
+
         self.embed_dim = embed_dim
         self.n_heads = n_heads
         assert embed_dim % n_heads == 0
         self.scaling = scaling
+        self.beta = beta
 
         self.head_dim = self.embed_dim // self.n_heads
 
@@ -180,17 +200,26 @@
                 "`source_size` must be a tuple of 1, 2 or 3 integers"
             )
 
-        q_util, q_img = q[:, : self.utility_tokens, :], q[:, self.utility_tokens :, :]
-        k_util, k_img = k[:, : self.utility_tokens, :], k[:, self.utility_tokens :, :]
+        q = rearrange(q, "b t (h d) -> b t h d", h=self.n_heads)
+        k = rearrange(k, "b t (h d) -> b t h d", h=self.n_heads)
+
+        q_util, q_img = (
+            q[:, : self.utility_tokens, :, :],
+            q[:, self.utility_tokens :, :, :],
+        )
+        k_util, k_img = (
+            k[:, : self.utility_tokens, :, :],
+            k[:, self.utility_tokens :, :, :],
+        )
 
         q_img = rearrange(
             q_img,
-            f"b ({spatial_dimension_names}) d -> b {spatial_dimension_names} d",
+            f"b ({spatial_dimension_names}) h d -> b {spatial_dimension_names} h d",
            **spatial_dimension_values,
         )
         k_img = rearrange(
             k_img,
-            f"b ({spatial_dimension_names}) d -> b {spatial_dimension_names} d",
+            f"b ({spatial_dimension_names}) h d -> b {spatial_dimension_names} h d",
            **spatial_dimension_values,
         )
 
@@ -201,17 +230,20 @@
 
         q_img = rearrange(
             q_img,
-            f"b {spatial_dimension_names} d -> b ({spatial_dimension_names}) d",
+            f"b {spatial_dimension_names} h d -> b ({spatial_dimension_names}) h d",
         )
         k_img = rearrange(
             k_img,
-            f"b {spatial_dimension_names} d -> b ({spatial_dimension_names}) d",
+            f"b {spatial_dimension_names} h d -> b ({spatial_dimension_names}) h d",
         )
 
         # Re-combine the utility tokens and the RoPE-enhanced sequence tokens
         q = torch.cat([q_util, q_img], dim=1)
         k = torch.cat([k_util, k_img], dim=1)
 
+        q = rearrange(q, "b t h d -> b t (h d)")
+        k = rearrange(k, "b t h d -> b t (h d)")
+
         return q, k
 
     def project_qkv(
@@ -242,7 +274,7 @@ class MHAttention(nn.Module):
 
         q, k, v = self.project_qkv(q, k, v)
 
-        if FLASH_ATTN:
+        if FLASH_ATTN and not self.talking_heads:
             # Divide Q/K/V into heads
             q = rearrange(q, "b t (h d) -> b t h d", h=self.n_heads)
             k = rearrange(k, "b t (h d) -> b t h d", h=self.n_heads)
@@ -270,12 +302,22 @@
 
         qk_scores *= self.scaling_factor
 
+        if self.talking_heads:
+            qk_scores = torch.einsum(
+                "b h i j, o h -> b o i j", qk_scores, self.head_projection.weight
+            )
+
         # Apply mask if causal (must come before softmax)
         if self.causal:
             qk_scores.masked_fill_(self.mask, float("-inf"))
 
         qk_scores = F.softmax(qk_scores, dim=-1)
 
+        if self.talking_heads:
+            qk_scores = torch.einsum(
+                "b h i j, o h -> b o i j", qk_scores, self.sample_projection.weight
+            )
+
         qk_scores = self.dropout(qk_scores)
 
         output_with_heads = qk_scores @ v
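Note: the two einsum calls above implement "talking heads" attention (Shazeer et al., 2020): the per-head attention logits, and again the post-softmax weights, are mixed across the head dimension by learned n_heads × n_heads projections that start at the identity. A standalone sketch of the mixing step, with illustrative shapes:

```python
import torch

b, h, i, j = 2, 4, 5, 5                  # batch, heads, queries, keys
qk_scores = torch.randn(b, h, i, j)
head_projection = torch.eye(h)           # identity init, as in reset_parameters()

# out[b, o, i, j] = sum_h head_projection[o, h] * qk_scores[b, h, i, j]
mixed = torch.einsum("b h i j, o h -> b o i j", qk_scores, head_projection)
assert torch.allclose(mixed, qk_scores)  # identity mixing leaves the scores unchanged
```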
@@ -308,7 +350,14 @@ class MHAttention(nn.Module):
         self.q_proj.reset_parameters()
         self.k_proj.reset_parameters()
         self.v_proj.reset_parameters()
+        scale_parameters(self.v_proj, self.beta)  # per Microsoft DeepNet
         self.out_proj.reset_parameters()
+        scale_parameters(self.out_proj, self.beta)  # per Microsoft DeepNet
+
+        if self.talking_heads:
+            # Initialize close to identity
+            nn.init.eye_(self.head_projection.weight)
+            nn.init.eye_(self.sample_projection.weight)
 
 
 class FeedforwardBlock(nn.Module):
@@ -324,19 +373,19 @@ class FeedforwardBlock(nn.Module):
         activation=nn.ReLU,
         activation_kwargs=None,
         dropout=0.0,
+        inner_dropout=None,
+        outer_dropout=None,
         linear_module_up=nn.Linear,
         linear_module_down=nn.Linear,
-        pre_norm=True,
         normformer=False,
-        post_norm=True,
-        residual_path=True,
         checkpoint=True,
+        beta=1.0,
     ):
         super().__init__()
 
         self.checkpoint = checkpoint
-        self.residual_path = residual_path
-        self.post_norm = post_norm
+        self.beta = beta
+        self.xglu = activation.__name__.endswith("GLU")
 
         if self.residual_path and (output_features < input_features):
             raise ValueError(
@@ -353,40 +402,76 @@ class FeedforwardBlock(nn.Module):
         else:
             self.activation = activation()
 
-        self.dropout = nn.Dropout(dropout)
+        self.inner_dropout = nn.Dropout(
+            inner_dropout if inner_dropout is not None else dropout
+        )
+        self.outer_dropout = nn.Dropout(
+            outer_dropout if outer_dropout is not None else dropout
+        )
 
         self.max_features = (
-            2 * ratio * output_features
-            if activation.__name__.endswith("GLU")
-            else ratio * output_features
+            2 * int(ratio * output_features)
+            if self.xglu
+            else int(ratio * output_features)
+        )
+
+        self.linear_in = linear_module_up(input_features, self.max_features)
+        self.linear_out = linear_module_down(
+            int(ratio * output_features), output_features
         )
 
         self.process = nn.Sequential(
             *[
-                nn.LayerNorm(input_features) if pre_norm else nn.Identity(),
-                linear_module_up(input_features, self.max_features),
+                self.linear_in,
                 self.activation,
-                nn.LayerNorm(ratio * output_features) if normformer else nn.Identity(),
-                linear_module_down(ratio * output_features, output_features),
-                self.dropout,
+                self.inner_dropout,
+                (
+                    nn.LayerNorm(int(ratio * output_features))
+                    if normformer
+                    else nn.Identity()
+                ),
+                self.linear_out,
+                self.outer_dropout,
             ]
         )
 
+        self.recycling_enabled = False
+        if hasattr(self.linear_in, "row_recycling_rate") and hasattr(
+            self.linear_out, "column_recycling_rate"
+        ):
+            self.recycling_enabled = True
+            self.master_recycling_rate = self.linear_in.row_recycling_rate
+            self.linear_in.row_recycling_rate = 0.0
+            self.linear_out.column_recycling_rate = 0.0
+            if (
+                hasattr(self.linear_in, "column_recycling_rate")
+                and self.linear_in.column_recycling_rate > 0
+            ) or (
+                hasattr(self.linear_out, "row_recycling_rate")
+                and self.linear_out.row_recycling_rate > 0
+            ):
+                raise NotImplementedError(
+                    "At the moment this layer can only support recycling linear "
+                    "layers if the in layer resets only rows and the out layer "
+                    "resets only columns."
+                )
+
         self.reset_parameters()
 
     def forward(self, x):
 
+        # Recycle weights if using recycling linear layers
+        if self.training and self.recycling_enabled:
+            indices = self.linear_out.get_reset_indices(1)
+            self.linear_in.reset_rows(indices, incoming_data=x)
+            self.linear_out.reset_columns(indices)
+
         if self.checkpoint:
            processed = checkpoint(self.process, x, use_reentrant=False)
         else:
             processed = self.process(x)
 
-        if self.residual_path and self.post_norm:
-            return self.layernorm(x + processed)
-        elif self.residual_path:
-            return x + processed
-        else:
-            return processed
+        return processed
 
     def reset_parameters(self):
@@ -397,8 +482,11 @@ class FeedforwardBlock(nn.Module):
         if self.post_norm:
             if hasattr(module, "reset_parameters"):
                 module.reset_parameters()
+
+        scale_parameters(self.linear_in, self.beta)  # per Microsoft DeepNet
+        scale_parameters(self.linear_out, self.beta)
 
-class TransformerBlock(nn.Module):
+class EncoderBlock(nn.Module):
     """
     Performs LayerNorms first (as in PyTorch Transformers when norm_first=True),
     which is also what is seen in e.g.
@@ -415,13 +503,16 @@
         relative_position_embedding=False,
         source_size=None,
         utility_tokens=0,
+        talking_heads=False,
         mlp_ratio=4,
         activation: nn.Module = nn.ReLU,
         activation_kwargs: Optional[dict] = None,
         ff_linear_module_up=None,
         ff_linear_module_down=None,
         msa_scaling="d",
-        mlp_dropout=0.0,
+        ff_dropout=0.0,
+        ff_inner_dropout=0.0,
+        ff_outer_dropout=0.0,
         msa_dropout=0.0,
         identity_probability=0.0,
         causal=False,
@@ -430,6 +521,8 @@ class TransformerBlock(nn.Module):
         post_norm=False,
         normformer=False,
         checkpoint_ff=True,
+        alpha=1.0,
+        beta=1.0,
     ):
         """
         Args:
@@ -441,15 +534,29 @@ class TransformerBlock(nn.Module):
 
         super().__init__()
 
+        if pre_norm and post_norm:
+            raise ValueError("A transformer cannot be both prenorm and postnorm.")
+
         self.pre_norm = pre_norm
         self.post_norm = post_norm
         self.normformer = normformer
 
+        self.alpha = alpha
+        self.beta = beta
+
         self.drop_path = DropPath(drop_prob=identity_probability, scale_by_keep=True)
 
-        self.layer_norm_1 = nn.LayerNorm(d_model)
-        self.layer_norm_2 = nn.LayerNorm(d_model)
-        self.layer_norm_3 = nn.LayerNorm(d_model)
+        if self.pre_norm:
+            self.pre_attention_norm = nn.LayerNorm(d_model)
+            self.pre_mlp_norm = nn.LayerNorm(d_model)
+
+        if normformer:
+            self.normformer_norm = nn.LayerNorm(d_model)
+
+        if self.post_norm:
+            self.input_norm = nn.LayerNorm(d_model)
+            self.post_attention_norm = nn.LayerNorm(d_model)
+            self.post_mlp_norm = nn.LayerNorm(d_model)
 
         if relative_position_embedding:
             max_freq = int(max(source_size) / 2)  # Suggested by Gemini!
@@ -473,7 +580,9 @@ class TransformerBlock(nn.Module):
             rotary_embedding=self.rotary_embedding,
             source_size=source_size,
             utility_tokens=utility_tokens,
+            talking_heads=talking_heads,
             scaling=msa_scaling,
+            beta=beta,
         )
 
         # Submodule for the feedforward process
@@ -483,7 +592,9 @@ class TransformerBlock(nn.Module):
             d_model,
             activation=activation,
             activation_kwargs=activation_kwargs,
-            dropout=mlp_dropout,
+            dropout=ff_dropout,
+            inner_dropout=ff_inner_dropout,
+            outer_dropout=ff_outer_dropout,
             linear_module_up=(
                 ff_linear_module_up
                 if ff_linear_module_up is not None
@@ -494,11 +605,9 @@ class TransformerBlock(nn.Module):
                 if ff_linear_module_down is not None
                 else linear_module
             ),
-            pre_norm=False,  # Handled outside the block
             normformer=normformer,
-            post_norm=False,  # Handled outside the block
-            residual_path=False,  # Handled outside the block
             checkpoint=checkpoint_ff,
+            beta=beta,
         )
 
         self.reset_parameters()
@@ -508,22 +617,34 @@ class TransformerBlock(nn.Module):
         return self.attn._kv_distance
 
     def forward(self, x):
+        if self.post_norm:
+            x = self.input_norm(x)
 
         if self.pre_norm:
-            x = self.layer_norm_1(x)
-            x = x + self.drop_path(self.attn(x, x, x))
-            x = self.layer_norm_2(x)
-            x = x + self.drop_path(self.ff(x))
-            if self.post_norm:  # i.e. in addition! Pre and post.
-                x = self.layer_norm_3(x)
-        elif self.post_norm:  # i.e. only, not prenorm, just post
-            x = x + self.drop_path(self.attn(x, x, x))
-            x = self.layer_norm_1(x)
-            x = x + self.drop_path(self.ff(x))
-            x = self.layer_norm_2(x)
-        else:  # Not pre or post norm. Stand well back.
-            x = x + self.drop_path(self.attn(x, x, x))
-            x = x + self.drop_path(self.ff(x))
+            process_x = self.pre_attention_norm(x)
+        else:
+            process_x = x
+
+        processed = self.drop_path(self.attn(process_x, process_x, process_x))
+
+        if self.normformer:
+            processed = self.normformer_norm(processed)
+
+        x = self.alpha * x + processed
+
+        if self.post_norm:
+            x = self.post_attention_norm(x)
+        elif self.pre_norm:
+            process_x = self.pre_mlp_norm(x)
+        else:
+            process_x = x
+
+        processed = self.drop_path(self.ff(process_x))
+
+        x = self.alpha * x + processed
+
+        if self.post_norm:
+            x = self.post_mlp_norm(x)
 
         return x
 
@@ -531,16 +652,26 @@ class TransformerBlock(nn.Module):
         """
        Give back the attention scores used in this layer.
         """
+        # Fix: Use the correct attribute name 'pre_attention_norm'
         if self.pre_norm:
-            x = self.layer_norm_1(x)
+            # We must normalize the input before measuring attention logits
+            # to match what the model actually sees during forward()
+            x = self.pre_attention_norm(x)
             return self.attn.attention_logits(x, x, x)
         else:
             return self.attn.attention_logits(x, x, x)
 
     def reset_parameters(self):
-        self.layer_norm_1.reset_parameters()
-        self.layer_norm_2.reset_parameters()
-        self.layer_norm_3.reset_parameters()
+        if self.pre_norm:
+            self.pre_attention_norm.reset_parameters()
+            self.pre_mlp_norm.reset_parameters()
+
+        if self.post_norm:
+            self.post_attention_norm.reset_parameters()
+            self.post_mlp_norm.reset_parameters()
+
+        if self.normformer:
+            self.normformer_norm.reset_parameters()
 
         self.attn.reset_parameters()
         self.ff.reset_parameters()
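Note: the new alpha scales the residual stream before each sublayer's output is added (see `x = self.alpha * x + processed` in the forward pass above), and beta scales the value/output and feedforward weights at initialisation (the `scale_parameters(..., self.beta)` calls), following Microsoft's DeepNet recipe. My reading of that paper is that for an encoder-only stack of N blocks it suggests roughly α = (2N)^(1/4) and β = (8N)^(-1/4); a small illustrative helper (broccoli itself just accepts the two numbers):

```python
def deepnet_alpha_beta(n_layers: int) -> tuple[float, float]:
    # DeepNet-style suggestion for an encoder-only stack (illustrative, not part of broccoli).
    alpha = (2 * n_layers) ** 0.25
    beta = (8 * n_layers) ** -0.25
    return alpha, beta

alpha, beta = deepnet_alpha_beta(12)  # roughly alpha ≈ 2.21, beta ≈ 0.32
```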
@@ -566,18 +697,23 @@ class TransformerEncoder(nn.Module):
         activation_kwargs: Optional[dict] = None,
         ff_linear_module_up=None,
         ff_linear_module_down=None,
-        mlp_dropout=0.0,
+        ff_dropout=0.0,
+        ff_inner_dropout=0.0,
+        ff_outer_dropout=0.0,
         msa_dropout=0.0,
         stochastic_depth=0.0,
         causal=False,
         linear_module=nn.Linear,
         utility_tokens=0,
+        talking_heads=False,
         return_utility_tokens=False,
         pre_norm=True,
         post_norm=False,
         normformer=False,
         msa_scaling="d",
         checkpoint_ff=True,
+        alpha=1.0,
+        beta=1.0,
     ):
         """
         Args:
@@ -590,10 +726,23 @@ class TransformerEncoder(nn.Module):
         if relative_position_embedding and (source_size is None):
             raise ValueError(
                 "`source_size` for TransformerEncoder cannot be None if"
-                " `position_embedding_type` is relative"
+                " `relative_position_embedding` is True"
+            )
+
+        if absolute_position_embedding and (seq_len is None):
+            raise ValueError(
+                "`seq_len` for TransformerEncoder cannot be None if"
+                " `absolute_position_embedding` is True"
             )
 
         super().__init__()
+
+        if FLASH_ATTN and talking_heads:
+            warnings.warn(
+                "Using talking heads currently prevents using flash attention.",
+                stacklevel=2,
+            )
+
         self.seq_len = seq_len
         self.n_heads = n_heads
         self._utility_tokens = utility_tokens
@@ -605,9 +754,12 @@ class TransformerEncoder(nn.Module):
                 torch.empty(self._utility_tokens, d_model)
             )
             nn.init.normal_(self._utility_token_embedding, mean=0.0, std=1.0)
-            self.full_sequence_length = self.seq_len + self._utility_tokens
         else:
             self._utility_token_embedding = None
+
+        if self._utility_tokens and (self.seq_len is not None):
+            self.full_sequence_length = self.seq_len + self._utility_tokens
+        else:
             self.full_sequence_length = self.seq_len
 
         self.d_model = d_model
@@ -619,7 +771,7 @@ class TransformerEncoder(nn.Module):
         else:
             self.absolute_position_embedding = None
 
-        self.mlp_dropout = mlp_dropout
+        self.mlp_dropout = ff_dropout
         self.msa_dropout = msa_dropout
         self.stochastic_depth = stochastic_depth
 
@@ -635,20 +787,23 @@ class TransformerEncoder(nn.Module):
 
         self.blocks = nn.ModuleList(
             [
-                TransformerBlock(
+                EncoderBlock(
                     self.full_sequence_length,
                     d_model,
                     n_heads,
                     relative_position_embedding=relative_position_embedding,
                     source_size=source_size,
                     utility_tokens=utility_tokens,
+                    talking_heads=talking_heads,
                     mlp_ratio=mlp_ratio,
                     activation=activation,
                     activation_kwargs=activation_kwargs,
                     ff_linear_module_up=ff_linear_module_up,
                     ff_linear_module_down=ff_linear_module_down,
                     msa_scaling=msa_scaling,
-                    mlp_dropout=mlp_dropout,
+                    ff_dropout=ff_dropout,
+                    ff_inner_dropout=ff_inner_dropout,
+                    ff_outer_dropout=ff_outer_dropout,
                     msa_dropout=msa_dropout,
                     identity_probability=self.stochastic_depth_probabilities[i],
                     causal=causal,
@@ -657,6 +812,8 @@ class TransformerEncoder(nn.Module):
                     post_norm=post_norm,
                     normformer=normformer,
                     checkpoint_ff=checkpoint_ff,
+                    alpha=alpha,
+                    beta=beta,
                 )
                 for i in range(n_layers)
             ]
@@ -677,13 +834,14 @@ class TransformerEncoder(nn.Module):
            x = x
 
         if self.absolute_position_embedding is not None:
-            x = x + self.absolute_position_embedding(
+            position_embedding = self.absolute_position_embedding(
                 torch.arange(
                     0, self.full_sequence_length, dtype=torch.long, device=x.device
                 ).unsqueeze(
                     0
                 )  # to shape (1, seq_len) to broadcast over batch
             )
+            x += position_embedding
 
         return x
 
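Note: putting the renamed and new encoder arguments together, a hedged construction sketch — the keyword names come from this diff, while the values and remaining defaults are illustrative:

```python
import torch
from broccoli.transformer import TransformerEncoder

encoder = TransformerEncoder(
    seq_len=64,
    d_model=128,
    n_layers=12,
    n_heads=4,
    talking_heads=True,      # bypasses the flash-attn path; warns if flash-attn is installed
    ff_dropout=0.1,          # replaces the old mlp_dropout argument
    ff_inner_dropout=0.0,    # applied right after the feedforward activation
    ff_outer_dropout=0.0,    # applied after the down-projection
    alpha=(2 * 12) ** 0.25,  # DeepNet-style residual scaling
    beta=(8 * 12) ** -0.25,
)
x = torch.randn(2, 64, 128)
y = encoder(x)
```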
broccoli/vit.py CHANGED
@@ -158,10 +158,11 @@ class ViTEncoder(nn.Module):
         pooling_kernel_stride=2,
         pooling_padding=1,
         transformer_feedforward_first=True,
-        transformer_initial_ff_residual_path=True,
         transformer_initial_ff_linear_module_up=None,
         transformer_initial_ff_linear_module_down=None,
-        transformer_initial_ff_mlp_dropout=None,
+        transformer_initial_ff_dropout=None,
+        transformer_initial_ff_inner_dropout=None,
+        transformer_initial_ff_outer_dropout=None,
         transformer_pre_norm=True,
         transformer_normformer=False,
         transformer_post_norm=False,
@@ -172,20 +173,28 @@ class ViTEncoder(nn.Module):
         transformer_heads=4,
         transformer_mlp_ratio=2,
         transformer_utility_tokens=0,
+        transformer_talking_heads=False,
         transformer_return_utility_tokens=False,
         transformer_activation: nn.Module = SquaredReLU,
         transformer_activation_kwargs: Optional[dict] = None,
         transformer_ff_linear_module_up=None,
         transformer_ff_linear_module_down=None,
         transformer_msa_scaling="d",
-        transformer_mlp_dropout=0.0,
+        transformer_ff_dropout=0.0,
+        transformer_ff_inner_dropout=0.0,
+        transformer_ff_outer_dropout=0.0,
         transformer_msa_dropout=0.1,
         transformer_stochastic_depth=0.1,
         transformer_checkpoint_ff=True,
         linear_module=nn.Linear,
+        alpha=1.0,
+        beta=1.0,
     ):
         super().__init__()
 
+        self.alpha = alpha
+        self.beta = beta
+
         if cnn_activation_kwargs is not None:
             self.cnn_activation = cnn_activation(**cnn_activation_kwargs)
         else:
@@ -333,17 +342,22 @@ class ViTEncoder(nn.Module):
                 ff_linear_module_up=transformer_ff_linear_module_up,
                 ff_linear_module_down=transformer_ff_linear_module_down,
                 msa_scaling=transformer_msa_scaling,
-                mlp_dropout=transformer_mlp_dropout,
+                ff_dropout=transformer_ff_dropout,
+                ff_inner_dropout=transformer_ff_inner_dropout,
+                ff_outer_dropout=transformer_ff_outer_dropout,
                 msa_dropout=transformer_msa_dropout,
                 stochastic_depth=transformer_stochastic_depth,
                 causal=False,
                 linear_module=linear_module,
                 utility_tokens=transformer_utility_tokens,
+                talking_heads=transformer_talking_heads,
                 return_utility_tokens=transformer_return_utility_tokens,
                 pre_norm=transformer_pre_norm,
                 normformer=transformer_normformer,
                 post_norm=transformer_post_norm,
                 checkpoint_ff=transformer_checkpoint_ff,
+                alpha=self.alpha,
+                beta=self.beta,
             )
         else:
             self.transformer = nn.Identity()
@@ -357,9 +371,21 @@ class ViTEncoder(nn.Module):
                 activation_kwargs=transformer_activation_kwargs,
                 dropout=(
                     # First truthy assigned value
-                    transformer_initial_ff_mlp_dropout
-                    if transformer_initial_ff_mlp_dropout is not None
-                    else transformer_mlp_dropout
+                    transformer_initial_ff_dropout
+                    if transformer_initial_ff_dropout is not None
+                    else transformer_ff_dropout
+                ),
+                inner_dropout=(
+                    # First truthy assigned value
+                    transformer_initial_ff_inner_dropout
+                    if transformer_initial_ff_inner_dropout is not None
+                    else transformer_ff_inner_dropout
+                ),
+                outer_dropout=(
+                    # First truthy assigned value
+                    transformer_initial_ff_outer_dropout
+                    if transformer_initial_ff_outer_dropout is not None
+                    else transformer_ff_outer_dropout
                 ),
                 linear_module_up=(
                     # First truthy assigned value
@@ -373,16 +399,14 @@
                    or transformer_ff_linear_module_down
                    or linear_module
                ),
-                pre_norm=transformer_pre_norm,
                normformer=transformer_normformer,
-                post_norm=transformer_post_norm,
-                residual_path=transformer_initial_ff_residual_path,
                checkpoint=transformer_checkpoint_ff,
+                beta=self.beta,
            )
        else:
            self.initial_ff = nn.Identity()
 
-        self.encoder = nn.Sequential(
+        self.preprocess = nn.Sequential(
            *[
                batchnormxd(in_channels) if initial_batch_norm else nn.Identity(),
                self.cnn,
@@ -392,19 +416,21 @@
                    f"N C {spatial_dim_names} -> N ({spatial_dim_names}) C"
                ),
                self.pooling_channels_padding,
-                self.initial_ff,
-                self.transformer,
+                nn.LayerNorm(),
            ]
        )
 
        self.reset_parameters()
 
    def forward(self, x):
-        return self.encoder(x)
+        x = self.preprocess(x)
+        x = x + self.initial_ff(x)
+        return self.transformer(x)
 
    def attention_logits(self, x):
-        x = self.encoder[:-1](x)
-        return self.encoder[-1].attention_logits(x)
+        x = self.preprocess(x)
+        x = x + self.initial_ff(x)
+        return self.transformer.attention_logits(x)
 
    def reset_parameters(self):
        for module in self.encoder:
@@ -438,10 +464,11 @@ class ViT(nn.Module):
         pooling_kernel_stride=2,
         pooling_padding=1,
         transformer_feedforward_first=True,
-        transformer_initial_ff_residual_path=True,
         transformer_initial_ff_linear_module_up=None,
         transformer_initial_ff_linear_module_down=None,
-        transformer_initial_ff_mlp_dropout=None,
+        transformer_initial_ff_dropout=None,
+        transformer_initial_ff_inner_dropout=None,
+        transformer_initial_ff_outer_dropout=None,
         transformer_pre_norm=True,
         transformer_normformer=False,
         transformer_post_norm=False,
@@ -452,13 +479,16 @@ class ViT(nn.Module):
         transformer_heads=4,
         transformer_mlp_ratio=2,
         transformer_utility_tokens=0,
+        transformer_talking_heads=False,
         transformer_return_utility_tokens=False,
         transformer_activation: nn.Module = SquaredReLU,
         transformer_activation_kwargs: Optional[dict] = None,
         transformer_ff_linear_module_up=None,
         transformer_ff_linear_module_down=None,
         transformer_msa_scaling="d",
-        transformer_mlp_dropout=0.0,
+        transformer_ff_dropout=0.0,
+        transformer_ff_inner_dropout=0.0,
+        transformer_ff_outer_dropout=0.0,
         transformer_msa_dropout=0.1,
         transformer_stochastic_depth=0.1,
         transformer_checkpoint_ff=True,
@@ -466,6 +496,8 @@ class ViT(nn.Module):
         batch_norm_logits=True,
         logit_projection_layer=nn.Linear,
         linear_module=nn.Linear,
+        alpha=1.0,
+        beta=1.0,
     ):
 
         super().__init__()
@@ -486,6 +518,9 @@ class ViT(nn.Module):
             "SwiGLU": SwiGLU,
         }[transformer_activation]
 
+        self.alpha = alpha
+        self.beta = beta
+
         self.encoder = ViTEncoder(
             input_size=input_size,
             initial_batch_norm=initial_batch_norm,
@@ -505,10 +540,11 @@ class ViT(nn.Module):
             pooling_kernel_stride=pooling_kernel_stride,
             pooling_padding=pooling_padding,
             transformer_feedforward_first=transformer_feedforward_first,
-            transformer_initial_ff_residual_path=transformer_initial_ff_residual_path,
             transformer_initial_ff_linear_module_up=transformer_initial_ff_linear_module_up,
             transformer_initial_ff_linear_module_down=transformer_initial_ff_linear_module_down,
-            transformer_initial_ff_mlp_dropout=transformer_initial_ff_mlp_dropout,
+            transformer_initial_ff_dropout=transformer_initial_ff_dropout,
+            transformer_initial_ff_inner_dropout=transformer_initial_ff_inner_dropout,
+            transformer_initial_ff_outer_dropout=transformer_initial_ff_outer_dropout,
             transformer_pre_norm=transformer_pre_norm,
             transformer_normformer=transformer_normformer,
             transformer_post_norm=transformer_post_norm,
@@ -519,17 +555,22 @@ class ViT(nn.Module):
             transformer_heads=transformer_heads,
             transformer_mlp_ratio=transformer_mlp_ratio,
             transformer_utility_tokens=transformer_utility_tokens,
+            transformer_talking_heads=transformer_talking_heads,
             transformer_return_utility_tokens=transformer_return_utility_tokens,
             transformer_activation=transformer_activation,
             transformer_activation_kwargs=transformer_activation_kwargs,
             transformer_ff_linear_module_up=transformer_ff_linear_module_up,
             transformer_ff_linear_module_down=transformer_ff_linear_module_down,
             transformer_msa_scaling=transformer_msa_scaling,
-            transformer_mlp_dropout=transformer_mlp_dropout,
+            transformer_ff_dropout=transformer_ff_dropout,
+            transformer_ff_inner_dropout=transformer_ff_inner_dropout,
+            transformer_ff_outer_dropout=transformer_ff_outer_dropout,
             transformer_msa_dropout=transformer_msa_dropout,
             transformer_stochastic_depth=transformer_stochastic_depth,
             transformer_checkpoint_ff=transformer_checkpoint_ff,
             linear_module=linear_module,
+            alpha=alpha,
+            beta=beta,
         )
 
         self.pool = head(
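Note: for callers of ViT/ViTEncoder, the keyword-argument changes in this release can be summarised as a plain mapping (for migration scripts or review; not part of the package):

```python
RENAMED_KWARGS = {
    "transformer_mlp_dropout": "transformer_ff_dropout",
    "transformer_initial_ff_mlp_dropout": "transformer_initial_ff_dropout",
}
REMOVED_KWARGS = [
    # The residual around the initial feedforward block is now always applied
    # in ViTEncoder.forward(), so the switch is gone.
    "transformer_initial_ff_residual_path",
]
NEW_KWARGS = [
    "transformer_ff_inner_dropout",
    "transformer_ff_outer_dropout",
    "transformer_initial_ff_inner_dropout",
    "transformer_initial_ff_outer_dropout",
    "transformer_talking_heads",
    "alpha",
    "beta",
]
```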
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: broccoli-ml
-Version: 6.0.0
+Version: 13.0.1
 Summary: Some useful Pytorch models, circa 2025
 License: MIT
 Author: Nicholas Bailey
@@ -0,0 +1,13 @@
+broccoli/__init__.py,sha256=tmyspsVxqPZHRQCY_NRwpW4SMNBbtE8E_8z7l-SAzSo,127
+broccoli/activation.py,sha256=nrpTOrpg9k23_E4AJWy7VlXXAJCtCJCOR-TonEWJr04,3218
+broccoli/cnn.py,sha256=WjoPDSpe3ttwxCBNfCVRdaCHvbeZ7G-a5_i8fUsK_d8,4889
+broccoli/linear.py,sha256=W-3aNpBjd_0xRyzbCKkmg4H1qmslQOIQhB-WDDay2nM,13125
+broccoli/rope.py,sha256=GRqApBNmYCFaDak0WL1xE_BC5CTTYKQU_PBdeTcQcjc,12557
+broccoli/tensor.py,sha256=um8mrxkYbvNDo-QvHlmJm8Aw6qcngOlUZPoAk_PMReA,4480
+broccoli/transformer.py,sha256=3vAQQ75SAyr4-m3e7vSru8M-RpUy2Enp5cVUafaVYMU,28410
+broccoli/utils.py,sha256=oOWzn6dJ5nC_9r4zq0emmfmaYACJXJNFS48AOpW2jqc,358
+broccoli/vit.py,sha256=jd4e6MjL2JKB8ynSQssWRh6Hs36RuLj4uWyUNVhIMUY,22472
+broccoli_ml-13.0.1.dist-info/LICENSE,sha256=0BAzJE5BqQ7Iixp_AFdB2W1uO-HCRX-Qfun8PHt6yVM,1073
+broccoli_ml-13.0.1.dist-info/METADATA,sha256=gN9cKQDpwRr8JG_Ilj0ZknlPdxhq62I8xEBExVspKGw,1369
+broccoli_ml-13.0.1.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
+broccoli_ml-13.0.1.dist-info/RECORD,,
@@ -1,13 +0,0 @@
-broccoli/__init__.py,sha256=tmyspsVxqPZHRQCY_NRwpW4SMNBbtE8E_8z7l-SAzSo,127
-broccoli/activation.py,sha256=-Jf30C6iGqWCorC9HEGn2oduWwjeaCAxGLUUYIy1zX8,3438
-broccoli/cnn.py,sha256=WjoPDSpe3ttwxCBNfCVRdaCHvbeZ7G-a5_i8fUsK_d8,4889
-broccoli/linear.py,sha256=Y7s-DzcwsOipRboNHc4HTScw4mJRalNoVFsNcxOB6a4,4872
-broccoli/rope.py,sha256=GRqApBNmYCFaDak0WL1xE_BC5CTTYKQU_PBdeTcQcjc,12557
-broccoli/tensor.py,sha256=um8mrxkYbvNDo-QvHlmJm8Aw6qcngOlUZPoAk_PMReA,4480
-broccoli/transformer.py,sha256=MxIdzoxoWx_IWcq86vDZJIV4tk-dMNivhopZu8zJk90,23293
-broccoli/utils.py,sha256=oOWzn6dJ5nC_9r4zq0emmfmaYACJXJNFS48AOpW2jqc,358
-broccoli/vit.py,sha256=9oyh76ulmX5lDPMCDicQhhqm8RYCvJIgAJkDbYRVdi4,20873
-broccoli_ml-6.0.0.dist-info/LICENSE,sha256=0BAzJE5BqQ7Iixp_AFdB2W1uO-HCRX-Qfun8PHt6yVM,1073
-broccoli_ml-6.0.0.dist-info/METADATA,sha256=Sv8nRPb7oCAeoMe3AAHIYDewAETvb0ZDxN8IKFniVHk,1368
-broccoli_ml-6.0.0.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
-broccoli_ml-6.0.0.dist-info/RECORD,,