broccoli-ml 9.2.2__py3-none-any.whl → 13.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
broccoli/activation.py CHANGED
@@ -46,10 +46,7 @@ class GELU(nn.Module):
 
 class Swish(nn.Module):
     """
-    Implementation of (beta) SwiGLU, as introduced in "GLU Variants Improve Transformer"
-    (https://arxiv.org/abs/2002.05202v1) and used to great effect in LLaMa 2.0.
-
-    Halves the incoming parameter count, which should be scaled up before input.
+    Implementation of (beta) Swish
     """
 
     def __init__(self) -> None:
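
The corrected docstring matches what the class computes: Swish is x · sigmoid(βx), a plain gated activation rather than a GLU variant, so it does not halve the feature dimension. A minimal sketch of a beta-Swish module, assuming a learnable β (the class body itself is not shown in this diff):

import torch
import torch.nn as nn


class BetaSwish(nn.Module):
    """Sketch: f(x) = x * sigmoid(beta * x), with a learnable beta (an assumption)."""

    def __init__(self) -> None:
        super().__init__()
        self.beta = nn.Parameter(torch.ones(1))  # hypothetical learnable slope

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * torch.sigmoid(self.beta * x)
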
broccoli/linear.py CHANGED
@@ -151,11 +151,13 @@ class RecyclingLinear(nn.Module):
         row_recycling_rate: float = 0.0,
         column_recycling_rate: float = 0.0,
         adaptive=False,
+        xglu=False,
     ):
         super().__init__()
         self.in_features = in_features
         self.out_features = out_features
         self.bias = bias
+        self.xglu = xglu
         self.linear = nn.Linear(in_features, out_features, bias=bias)
         self.row_recycling_rate = row_recycling_rate
         self.column_recycling_rate = column_recycling_rate
@@ -191,28 +193,60 @@ class RecyclingLinear(nn.Module):
         multipliers = [a / b for a, b in pairs if b != 0.0]
         return min(multipliers) if multipliers else 0.0
 
-    def forward(self, x):
-        multiplier = self._get_multiplier()
-        col_recycling_rate = self.column_recycling_rate * multiplier
-        row_recycling_rate = self.row_recycling_rate * multiplier
+    def reset_rows(self, indices, incoming_data=None):
+        """
+        Resets rows.
+        If incoming_data is provided, resets to the centroid (mean) of that data.
+        If not, resets to the mean of existing weights.
+        """
+        if not torch.is_tensor(indices):
+            idx_tensor = torch.as_tensor(
+                list(indices), dtype=torch.long, device=self.linear.weight.device
+            )
+        else:
+            idx_tensor = indices
 
-        if self.training and self.optimisers:
+        if idx_tensor.numel() == 0:
+            return
+
+        if incoming_data is not None:
+            target_center = self._mean_input_weights(incoming_data)
+        else:
+            target_center = self._mean_value_weights()
 
-            if row_recycling_rate > 0:
-                probs = torch.rand(self.linear.out_features, device=x.device)
-                mask = probs < row_recycling_rate
-                if mask.any():
-                    # nonzero returns [N, 1], squeeze to get [N]
-                    indices = torch.nonzero(mask).squeeze(-1)
-                    self.reset_rows(indices, self.optimisers)
-
-            if col_recycling_rate > 0:
-                probs = torch.rand(self.linear.in_features, device=x.device)
-                mask = probs < col_recycling_rate
-                if mask.any():
-                    indices = torch.nonzero(mask).squeeze(-1)
-                    self.reset_columns(indices, self.optimisers)
+        target_center = target_center.expand(idx_tensor.size(0), -1)
 
+        if self.xglu:
+            gate_indices = idx_tensor
+            value_indices = idx_tensor + (self.linear.out_features // 2)
+            self._update_weights(gate_indices, 0, target_center, self.optimisers)
+            self._update_weights(value_indices, 0, target_center, self.optimisers)
+        else:
+            self._update_weights(idx_tensor, 0, target_center, self.optimisers)
+
+    def reset_columns(self, indices):
+        if not torch.is_tensor(indices):
+            idx_tensor = torch.as_tensor(
+                list(indices), dtype=torch.long, device=self.linear.weight.device
+            )
+        else:
+            idx_tensor = indices
+
+        if idx_tensor.size(0):
+            random_weights = self._random_weights(
+                self.linear.weight.size(0), indices.size(0)
+            )
+            # Make random col weights quiet so they don't introduce loud noise...
+            # ...but not so quiet that FP16 zeros them and ruins symmetry breaking!
+            random_weights *= 0.1
+            self._update_weights(indices, 1, random_weights, self.optimisers)  # dim
+        else:
+            return
+
+    def forward(self, x):
+        if self.training and self.optimisers:
+            self.reset_rows(self.get_reset_indices(0))
+            self.reset_columns(self.get_reset_indices(1))
         elif self.training and not self._warned_about_registration:
             warnings.warn(
                 "RecyclingLinear: No optimiser registered. Recycling disabled.",
@@ -222,82 +256,99 @@ class RecyclingLinear(nn.Module):
 
         return self.linear(x)
 
-    def reset_rows(
-        self,
-        indices: Iterable[int],
-        optimisers: Union[
-            List[torch.optim.Optimizer], torch.optim.Optimizer, None
-        ] = None,
-    ):
-        """
-        Update some of the weight rows to be equal to the mean of all weight rows.
-        """
-        if optimisers is None:
-            optimisers = []
-        if not isinstance(optimisers, list):
-            optimisers = [optimisers]
+    def get_reset_indices(self, dim):
+        base_rate = self.row_recycling_rate if dim == 0 else self.column_recycling_rate
+        p = base_rate * self._get_multiplier()
+        if dim == 0:
+            if self.xglu:
+                sample_space = self.linear.out_features // 2
+            else:
+                sample_space = self.linear.out_features
+        elif dim == 1:
+            sample_space = self.linear.in_features
+        else:
+            raise ValueError("`dim` must be 0 or 1")
 
+        # Sample the indices
+        probs = torch.rand(sample_space, device=self.linear.weight.device)
+        mask = probs < p
+        if mask.any():
+            return torch.nonzero(mask).squeeze(-1)
+        else:
+            return torch.tensor([], dtype=torch.long, device=self.linear.weight.device)
+
+    def _random_weights(self, rows, columns):
         device = self.linear.weight.device
-        idx_tensor = torch.as_tensor(list(indices), dtype=torch.long, device=device)
+        weights = self.linear.weight.data
+        stdv = 1.0 / math.sqrt(weights.size(1))
+        random_weights = torch.rand(rows, columns, device=device)
+        random_weights -= 0.5  # Range [-0.5, +0.5]
+        random_weights *= 2.0 * stdv  # Range [-stdv, +stdv]
+        return random_weights
 
-        if idx_tensor.numel() == 0:
-            return
+    def _mean_input_weights(self, input):
+        reduce_dims = list(range(input.ndim - 1))
+        data_mean = input.detach().mean(dim=reduce_dims, keepdim=True)
 
-        with torch.no_grad():
-            # Calculate mean of all rows including the rows to be reset
-            mean_vector = self.linear.weight.data.mean(
-                dim=0, keepdim=True
-            )  # [1, in_features]
-            update_data = mean_vector.expand(idx_tensor.size(0), -1)
-            self.linear.weight.data[idx_tensor] = update_data
+        weights = self.linear.weight.data
+        stdv = 1.0 / math.sqrt(weights.size(1))
+        data_norm = data_mean.std() + 1e-6
+        scale_factor = stdv / data_norm
+
+        return data_mean * scale_factor
 
-        if self.linear.bias is not None:
-            self.linear.bias.data[idx_tensor] = 0.0
+    def _mean_value_weights(self):
+        """
+        Only used when self.xglu
+        """
+        weights = self.linear.weight.data
+        rows = weights.size(0)
+        if self.xglu:
+            return self.linear.weight[int(rows / 2) :].data.mean(dim=0, keepdim=True)
+        else:
+            return self.linear.weight.data.mean(dim=0, keepdim=True)
 
-        self._reset_optim_state(self.linear.weight, idx_tensor, optimisers, dim=0)
-        if self.linear.bias is not None:
-            self._reset_optim_state(self.linear.bias, idx_tensor, optimisers, dim=0)
+    def _mean_gate_weights(self):
+        """
+        Only used when self.xglu
+        """
+        weights = self.linear.weight.data
+        rows = weights.size(0)
+        return self.linear.weight[: int(rows / 2)].data.mean(dim=0, keepdim=True)
 
-    def reset_columns(
+    def _update_weights(
         self,
         indices: Iterable[int],
+        dim: int,
+        data: torch.Tensor,
         optimisers: Union[
            List[torch.optim.Optimizer], torch.optim.Optimizer, None
        ] = None,
    ):
-        """
-        Update some of the weight columns to be random as though reinitialised.
-        """
        if optimisers is None:
            optimisers = []
        if not isinstance(optimisers, list):
            optimisers = [optimisers]
 
-        device = self.linear.weight.device
-        idx_tensor = torch.as_tensor(list(indices), dtype=torch.long, device=device)
+        if not torch.is_tensor(indices):
+            idx_tensor = torch.as_tensor(
+                list(indices), dtype=torch.long, device=self.linear.weight.device
+            )
+        else:
+            idx_tensor = indices
 
        if idx_tensor.numel() == 0:
            return
 
        with torch.no_grad():
-            # 1. Generate Random Columns
-            # Shape: [out_features, N_indices]
-            weights = self.linear.weight.data
-            stdv = 1.0 / math.sqrt(weights.size(1))
-
-            # Generate [Rows, N] block
-            random_weights = torch.rand(
-                weights.size(0), idx_tensor.size(0), device=device
-            )
-            random_weights = (random_weights - 0.5) * 2.0 * stdv
-
-            # 2. Update Weights (One-shot)
-            # We assign into the columns specified by idx_tensor
-            self.linear.weight.data[:, idx_tensor] = random_weights
-
-            # 3. Update Optimizers
-            # Bias is untouched by column resets (bias is shape [Out], cols are [In])
-            self._reset_optim_state(self.linear.weight, idx_tensor, optimisers, dim=1)
+            if dim == 0:
+                self.linear.weight.data[idx_tensor] = data
+            elif dim == 1:
+                self.linear.weight.data[:, idx_tensor] = data
+            else:
+                raise ValueError("`dim` must be 0 or 1")
+
+        self._reset_optim_state(self.linear.weight, idx_tensor, optimisers, dim=dim)
 
     def _reset_optim_state(self, param, idx_tensor, optimisers, dim):
         """
broccoli/transformer.py CHANGED
@@ -1,3 +1,4 @@
+import warnings
 import math
 from typing import Optional, Tuple
 
@@ -20,6 +21,12 @@ except ImportError:
     FLASH_ATTN = False
 
 
+def scale_parameters(torch_module: nn.Module, factor: float):
+    with torch.no_grad():
+        for param in torch_module.parameters():
+            param.mul_(factor)
+
+
 def drop_path(
     x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True
 ):
@@ -78,9 +85,11 @@ class MHAttention(nn.Module):
         seq_len=None,
         linear_module: nn.Module = nn.Linear,
         utility_tokens=0,
+        talking_heads=False,
         rotary_embedding=None,
         source_size=None,
         scaling="d",
+        beta=1.0,
     ):
         """
         Args:
@@ -96,10 +105,20 @@ class MHAttention(nn.Module):
         if causal:
             assert seq_len is not None
 
+        self.talking_heads = talking_heads
+
+        if self.talking_heads:
+            self.head_projection = nn.Linear(n_heads, n_heads, bias=False)
+            self.sample_projection = nn.Linear(n_heads, n_heads, bias=False)
+        else:
+            self.head_projection = None
+            self.sample_projection = None
+
         self.embed_dim = embed_dim
         self.n_heads = n_heads
         assert embed_dim % n_heads == 0
         self.scaling = scaling
+        self.beta = beta
 
         self.head_dim = self.embed_dim // self.n_heads
 
@@ -181,17 +200,26 @@ class MHAttention(nn.Module):
                 "`source_size` must be a tuple of 1, 2 or 3 integers"
             )
 
-        q_util, q_img = q[:, : self.utility_tokens, :], q[:, self.utility_tokens :, :]
-        k_util, k_img = k[:, : self.utility_tokens, :], k[:, self.utility_tokens :, :]
+        q = rearrange(q, "b t (h d) -> b t h d", h=self.n_heads)
+        k = rearrange(k, "b t (h d) -> b t h d", h=self.n_heads)
+
+        q_util, q_img = (
+            q[:, : self.utility_tokens, :, :],
+            q[:, self.utility_tokens :, :, :],
+        )
+        k_util, k_img = (
+            k[:, : self.utility_tokens, :, :],
+            k[:, self.utility_tokens :, :, :],
+        )
 
         q_img = rearrange(
             q_img,
-            f"b ({spatial_dimension_names}) d -> b {spatial_dimension_names} d",
+            f"b ({spatial_dimension_names}) h d -> b {spatial_dimension_names} h d",
             **spatial_dimension_values,
         )
         k_img = rearrange(
             k_img,
-            f"b ({spatial_dimension_names}) d -> b {spatial_dimension_names} d",
+            f"b ({spatial_dimension_names}) h d -> b {spatial_dimension_names} h d",
             **spatial_dimension_values,
         )
 
@@ -202,17 +230,20 @@ class MHAttention(nn.Module):
 
         q_img = rearrange(
             q_img,
-            f"b {spatial_dimension_names} d -> b ({spatial_dimension_names}) d",
+            f"b {spatial_dimension_names} h d -> b ({spatial_dimension_names}) h d",
         )
         k_img = rearrange(
             k_img,
-            f"b {spatial_dimension_names} d -> b ({spatial_dimension_names}) d",
+            f"b {spatial_dimension_names} h d -> b ({spatial_dimension_names}) h d",
         )
 
         # Re-combine the utility tokens and the RoPE-enhanced sequence tokens
         q = torch.cat([q_util, q_img], dim=1)
         k = torch.cat([k_util, k_img], dim=1)
 
+        q = rearrange(q, "b t h d -> b t (h d)")
+        k = rearrange(k, "b t h d -> b t (h d)")
+
         return q, k
 
     def project_qkv(
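
The RoPE hunks above now unpack the head axis before the rotary embedding is applied, so each head's d-dimensional sub-vector is rotated independently instead of rotating the packed (h·d) vector as a whole. A shape walk-through with einops, using hypothetical sizes and ignoring the utility-token split:

import torch
from einops import rearrange

b, t, h, d = 2, 16, 4, 8      # batch, tokens, heads, head_dim (hypothetical)
q = torch.randn(b, t, h * d)  # packed output of the Q projection

q = rearrange(q, "b t (h d) -> b t h d", h=h)               # expose heads: [2, 16, 4, 8]
q_img = rearrange(q, "b (x y) h d -> b x y h d", x=4, y=4)  # restore the 2-D grid for RoPE
q_img = rearrange(q_img, "b x y h d -> b (x y) h d")        # flatten back after rotation
q = rearrange(q_img, "b t h d -> b t (h d)")                # re-pack heads: [2, 16, 32]
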
@@ -243,7 +274,7 @@ class MHAttention(nn.Module):
 
         q, k, v = self.project_qkv(q, k, v)
 
-        if FLASH_ATTN:
+        if FLASH_ATTN and not self.talking_heads:
             # Divide Q/K/V into heads
             q = rearrange(q, "b t (h d) -> b t h d", h=self.n_heads)
             k = rearrange(k, "b t (h d) -> b t h d", h=self.n_heads)
@@ -271,12 +302,22 @@ class MHAttention(nn.Module):
 
         qk_scores *= self.scaling_factor
 
+        if self.talking_heads:
+            qk_scores = torch.einsum(
+                "b h i j, o h -> b o i j", qk_scores, self.head_projection.weight
+            )
+
         # Apply mask if causal (must come before softmax)
         if self.causal:
             qk_scores.masked_fill_(self.mask, float("-inf"))
 
         qk_scores = F.softmax(qk_scores, dim=-1)
 
+        if self.talking_heads:
+            qk_scores = torch.einsum(
+                "b h i j, o h -> b o i j", qk_scores, self.sample_projection.weight
+            )
+
         qk_scores = self.dropout(qk_scores)
 
         output_with_heads = qk_scores @ v
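
The two einsums implement talking-heads attention (Shazeer et al., 2020): a learned h×h mixing of attention scores across heads, applied once to the logits and once to the post-softmax weights. The near-identity initialisation in `reset_parameters` below makes the mixing a no-op at the start of training. A standalone sketch of the logit-side mix:

import torch

b, h, i, j = 2, 4, 16, 16           # batch, heads, query length, key length
qk_scores = torch.randn(b, h, i, j)
head_projection = torch.eye(h)      # near-identity init, as in reset_parameters

# "b h i j, o h -> b o i j": each output head o is a learned mix of the input heads
mixed = torch.einsum("b h i j, o h -> b o i j", qk_scores, head_projection)
assert torch.allclose(mixed, qk_scores)  # identity weights leave the scores unchanged
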
@@ -309,7 +350,14 @@ class MHAttention(nn.Module):
         self.q_proj.reset_parameters()
         self.k_proj.reset_parameters()
         self.v_proj.reset_parameters()
+        scale_parameters(self.v_proj, self.beta)  # per Microsoft DeepNet
         self.out_proj.reset_parameters()
+        scale_parameters(self.out_proj, self.beta)  # per Microsoft DeepNet
+
+        if self.talking_heads:
+            # Initialize close to identity
+            nn.init.eye_(self.head_projection.weight)
+            nn.init.eye_(self.sample_projection.weight)
 
 
 class FeedforwardBlock(nn.Module):
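
The `scale_parameters(..., self.beta)` calls follow DeepNet (Wang et al., 2022), which shrinks the initialisation of the value and output projections by β while residual branches are up-weighted by α. For an N-layer encoder-only model the paper uses α = (2N)^(1/4) and β = (8N)^(-1/4); the library leaves both at 1.0 by default, so choosing them is left to the caller. A sketch of that choice:

def deepnet_encoder_constants(n_layers: int) -> tuple[float, float]:
    """DeepNet constants for an encoder-only stack of n_layers (Wang et al., 2022)."""
    alpha = (2 * n_layers) ** 0.25   # up-weights the residual branch
    beta = (8 * n_layers) ** -0.25   # down-scales v/out projection (and FF) init
    return alpha, beta


alpha, beta = deepnet_encoder_constants(12)
print(f"alpha={alpha:.3f}, beta={beta:.3f}")  # alpha≈2.213, beta≈0.319
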
@@ -329,17 +377,14 @@ class FeedforwardBlock(nn.Module):
         outer_dropout=None,
         linear_module_up=nn.Linear,
         linear_module_down=nn.Linear,
-        pre_norm=True,
         normformer=False,
-        post_norm=True,
-        residual_path=True,
         checkpoint=True,
+        beta=1.0,
     ):
         super().__init__()
 
         self.checkpoint = checkpoint
-        self.residual_path = residual_path
-        self.post_norm = post_norm
+        self.beta = beta
         self.xglu = activation.__name__.endswith("GLU")
 
         if self.residual_path and (output_features < input_features):
@@ -365,19 +410,26 @@ class FeedforwardBlock(nn.Module):
         )
 
         self.max_features = (
-            2 * ratio * output_features if self.xglu else ratio * output_features
+            2 * int(ratio * output_features)
+            if self.xglu
+            else int(ratio * output_features)
         )
 
         self.linear_in = linear_module_up(input_features, self.max_features)
-        self.linear_out = linear_module_down(ratio * output_features, output_features)
+        self.linear_out = linear_module_down(
+            int(ratio * output_features), output_features
+        )
 
         self.process = nn.Sequential(
             *[
-                nn.LayerNorm(input_features) if pre_norm else nn.Identity(),
                 self.linear_in,
                 self.activation,
                 self.inner_dropout,
-                nn.LayerNorm(ratio * output_features) if normformer else nn.Identity(),
+                (
+                    nn.LayerNorm(int(ratio * output_features))
+                    if normformer
+                    else nn.Identity()
+                ),
                 self.linear_out,
                 self.outer_dropout,
             ]
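
The `2 *` factor exists because GLU-family activations split the up-projection output into a gate half and a value half, halving the width again on the way through; the new `int(...)` calls simply make fractional `ratio` values legal. A minimal SwiGLU-style sketch of why the doubled width is needed:

import torch
import torch.nn.functional as F


def swiglu(x: torch.Tensor) -> torch.Tensor:
    """GLU-family activation: gates one half of the features with the other half."""
    gate, value = x.chunk(2, dim=-1)
    return F.silu(gate) * value


hidden = torch.randn(2, 16, 2 * 128)  # the up-projection must be 2x the nominal width
out = swiglu(hidden)                  # -> [2, 16, 128], matching linear_out's input size
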
@@ -410,33 +462,16 @@ class FeedforwardBlock(nn.Module):
 
         # Recycle weights if using recycling linear layers
         if self.training and self.recycling_enabled:
-            multiplier = self.linear_in._get_multiplier()
-            rate = self.master_recycling_rate * multiplier
-            if rate > 0:
-                probs = torch.rand(self.linear_out.in_features, device=x.device)
-                mask = probs < rate
-                if mask.any():
-                    indices = torch.nonzero(mask).squeeze(-1)
-                    self.linear_out.reset_columns(indices, self.linear_out.optimisers)
-                    if self.xglu:
-                        indices_in = torch.cat(
-                            [indices, indices + self.linear_out.in_features]
-                        )
-                        self.linear_in.reset_rows(indices_in, self.linear_in.optimisers)
-                    else:
-                        self.linear_in.reset_rows(indices, self.linear_in.optimisers)
+            indices = self.linear_out.get_reset_indices(1)
+            self.linear_in.reset_rows(indices, incoming_data=x)
+            self.linear_out.reset_columns(indices)
 
         if self.checkpoint:
             processed = checkpoint(self.process, x, use_reentrant=False)
         else:
             processed = self.process(x)
 
-        if self.residual_path and self.post_norm:
-            return self.layernorm(x + processed)
-        elif self.residual_path:
-            return x + processed
-        else:
-            return processed
+        return processed
 
     def reset_parameters(self):
         if self.post_norm:
@@ -447,8 +482,11 @@ class FeedforwardBlock(nn.Module):
             if hasattr(module, "reset_parameters"):
                 module.reset_parameters()
 
+        scale_parameters(self.linear_in, self.beta)  # per Microsoft DeepNet
+        scale_parameters(self.linear_out, self.beta)
+
 
-class TransformerBlock(nn.Module):
+class EncoderBlock(nn.Module):
     """
     Performs LayerNorms first (as in PyTorch Transformers when norm_first=True),
     which is also what is seen in e.g.
@@ -465,6 +503,7 @@ class TransformerBlock(nn.Module):
         relative_position_embedding=False,
         source_size=None,
         utility_tokens=0,
+        talking_heads=False,
         mlp_ratio=4,
         activation: nn.Module = nn.ReLU,
         activation_kwargs: Optional[dict] = None,
@@ -482,6 +521,8 @@ class TransformerBlock(nn.Module):
         post_norm=False,
         normformer=False,
         checkpoint_ff=True,
+        alpha=1.0,
+        beta=1.0,
     ):
         """
         Args:
@@ -493,15 +534,29 @@ class TransformerBlock(nn.Module):
 
         super().__init__()
 
+        if pre_norm and post_norm:
+            raise ValueError("A transformer cannot be both prenorm and postnorm.")
+
         self.pre_norm = pre_norm
         self.post_norm = post_norm
         self.normformer = normformer
 
+        self.alpha = alpha
+        self.beta = beta
+
         self.drop_path = DropPath(drop_prob=identity_probability, scale_by_keep=True)
 
-        self.layer_norm_1 = nn.LayerNorm(d_model)
-        self.layer_norm_2 = nn.LayerNorm(d_model)
-        self.layer_norm_3 = nn.LayerNorm(d_model)
+        if self.pre_norm:
+            self.pre_attention_norm = nn.LayerNorm(d_model)
+            self.pre_mlp_norm = nn.LayerNorm(d_model)
+
+        if normformer:
+            self.normformer_norm = nn.LayerNorm(d_model)
+
+        if self.post_norm:
+            self.input_norm = nn.LayerNorm(d_model)
+            self.post_attention_norm = nn.LayerNorm(d_model)
+            self.post_mlp_norm = nn.LayerNorm(d_model)
 
         if relative_position_embedding:
             max_freq = int(max(source_size) / 2)  # Suggested by Gemini!
@@ -525,7 +580,9 @@ class TransformerBlock(nn.Module):
             rotary_embedding=self.rotary_embedding,
             source_size=source_size,
             utility_tokens=utility_tokens,
+            talking_heads=talking_heads,
             scaling=msa_scaling,
+            beta=beta,
         )
 
         # Submodule for the feedforward process
@@ -548,11 +605,9 @@ class TransformerBlock(nn.Module):
                 if ff_linear_module_down is not None
                 else linear_module
             ),
-            pre_norm=False,  # Handled outside the block
             normformer=normformer,
-            post_norm=False,  # Handled outside the block
-            residual_path=False,  # Handled outside the block
             checkpoint=checkpoint_ff,
+            beta=beta,
         )
 
         self.reset_parameters()
@@ -562,22 +617,34 @@ class TransformerBlock(nn.Module):
         return self.attn._kv_distance
 
     def forward(self, x):
+        if self.post_norm:
+            x = self.input_norm(x)
 
         if self.pre_norm:
-            x = self.layer_norm_1(x)
-            x = x + self.drop_path(self.attn(x, x, x))
-            x = self.layer_norm_2(x)
-            x = x + self.drop_path(self.ff(x))
-            if self.post_norm:  # i.e. in addition! Pre and post.
-                x = self.layer_norm_3(x)
-        elif self.post_norm:  # i.e. only, not prenorm, just post
-            x = x + self.drop_path(self.attn(x, x, x))
-            x = self.layer_norm_1(x)
-            x = x + self.drop_path(self.ff(x))
-            x = self.layer_norm_2(x)
-        else:  # Not pre or post norm. Stand well back.
-            x = x + self.drop_path(self.attn(x, x, x))
-            x = x + self.drop_path(self.ff(x))
+            process_x = self.pre_attention_norm(x)
+        else:
+            process_x = x
+
+        processed = self.drop_path(self.attn(process_x, process_x, process_x))
+
+        if self.normformer:
+            processed = self.normformer_norm(processed)
+
+        x = self.alpha * x + processed
+
+        if self.post_norm:
+            x = self.post_attention_norm(x)
+        elif self.pre_norm:
+            process_x = self.pre_mlp_norm(x)
+        else:
+            process_x = x
+
+        processed = self.drop_path(self.ff(process_x))
+
+        x = self.alpha * x + processed
+
+        if self.post_norm:
+            x = self.post_mlp_norm(x)
 
         return x
 
@@ -585,16 +652,26 @@ class TransformerBlock(nn.Module):
         """
         Give back the attention scores used in this layer.
         """
+        # Fix: Use the correct attribute name 'pre_attention_norm'
         if self.pre_norm:
-            x = self.layer_norm_1(x)
+            # We must normalize the input before measuring attention logits
+            # to match what the model actually sees during forward()
+            x = self.pre_attention_norm(x)
             return self.attn.attention_logits(x, x, x)
         else:
             return self.attn.attention_logits(x, x, x)
 
     def reset_parameters(self):
-        self.layer_norm_1.reset_parameters()
-        self.layer_norm_2.reset_parameters()
-        self.layer_norm_3.reset_parameters()
+        if self.pre_norm:
+            self.pre_attention_norm.reset_parameters()
+            self.pre_mlp_norm.reset_parameters()
+
+        if self.post_norm:
+            self.post_attention_norm.reset_parameters()
+            self.post_mlp_norm.reset_parameters()
+
+        if self.normformer:
+            self.normformer_norm.reset_parameters()
 
         self.attn.reset_parameters()
         self.ff.reset_parameters()
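
The rewritten `forward` above collapses the old three-way branching into a single residual pipeline, x ← α·x + sublayer(norm(x)), with the norms placed according to the pre-norm/post-norm/normformer flags and the residual up-weighted by α as in DeepNet. A stripped-down sketch of one such step in the pre-norm flavour, with hypothetical values:

import torch
import torch.nn as nn

d_model, alpha = 64, 1.5                # alpha > 1 per DeepNet; hypothetical values
norm = nn.LayerNorm(d_model)
sublayer = nn.Linear(d_model, d_model)  # stand-in for attention or the feedforward

x = torch.randn(2, 10, d_model)
x = alpha * x + sublayer(norm(x))       # one pre-norm residual step of EncoderBlock.forward
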
@@ -628,12 +705,15 @@ class TransformerEncoder(nn.Module):
         causal=False,
         linear_module=nn.Linear,
         utility_tokens=0,
+        talking_heads=False,
         return_utility_tokens=False,
         pre_norm=True,
         post_norm=False,
         normformer=False,
         msa_scaling="d",
         checkpoint_ff=True,
+        alpha=1.0,
+        beta=1.0,
     ):
         """
         Args:
@@ -656,6 +736,13 @@ class TransformerEncoder(nn.Module):
         )
 
         super().__init__()
+
+        if FLASH_ATTN and talking_heads:
+            warnings.warn(
+                "Using talking heads currently prevents using flash attention.",
+                stacklevel=2,
+            )
+
         self.seq_len = seq_len
         self.n_heads = n_heads
         self._utility_tokens = utility_tokens
@@ -700,13 +787,14 @@ class TransformerEncoder(nn.Module):
 
         self.blocks = nn.ModuleList(
             [
-                TransformerBlock(
+                EncoderBlock(
                     self.full_sequence_length,
                     d_model,
                     n_heads,
                     relative_position_embedding=relative_position_embedding,
                     source_size=source_size,
                     utility_tokens=utility_tokens,
+                    talking_heads=talking_heads,
                     mlp_ratio=mlp_ratio,
                     activation=activation,
                     activation_kwargs=activation_kwargs,
@@ -724,6 +812,8 @@ class TransformerEncoder(nn.Module):
                     post_norm=post_norm,
                     normformer=normformer,
                     checkpoint_ff=checkpoint_ff,
+                    alpha=alpha,
+                    beta=beta,
                 )
                 for i in range(n_layers)
             ]
@@ -744,13 +834,14 @@ class TransformerEncoder(nn.Module):
             x = x
 
         if self.absolute_position_embedding is not None:
-            x = x + self.absolute_position_embedding(
+            position_embedding = self.absolute_position_embedding(
                 torch.arange(
                     0, self.full_sequence_length, dtype=torch.long, device=x.device
                 ).unsqueeze(
                     0
                 )  # to shape (1, seq_len) to broadcast over batch
             )
+            x += position_embedding
 
         return x
 
broccoli/vit.py CHANGED
@@ -158,7 +158,6 @@ class ViTEncoder(nn.Module):
         pooling_kernel_stride=2,
         pooling_padding=1,
         transformer_feedforward_first=True,
-        transformer_initial_ff_residual_path=True,
         transformer_initial_ff_linear_module_up=None,
         transformer_initial_ff_linear_module_down=None,
         transformer_initial_ff_dropout=None,
@@ -174,6 +173,7 @@ class ViTEncoder(nn.Module):
         transformer_heads=4,
         transformer_mlp_ratio=2,
         transformer_utility_tokens=0,
+        transformer_talking_heads=False,
         transformer_return_utility_tokens=False,
         transformer_activation: nn.Module = SquaredReLU,
         transformer_activation_kwargs: Optional[dict] = None,
@@ -187,9 +187,14 @@ class ViTEncoder(nn.Module):
         transformer_stochastic_depth=0.1,
         transformer_checkpoint_ff=True,
         linear_module=nn.Linear,
+        alpha=1.0,
+        beta=1.0,
     ):
         super().__init__()
 
+        self.alpha = alpha
+        self.beta = beta
+
         if cnn_activation_kwargs is not None:
             self.cnn_activation = cnn_activation(**cnn_activation_kwargs)
         else:
@@ -345,11 +350,14 @@ class ViTEncoder(nn.Module):
                 causal=False,
                 linear_module=linear_module,
                 utility_tokens=transformer_utility_tokens,
+                talking_heads=transformer_talking_heads,
                 return_utility_tokens=transformer_return_utility_tokens,
                 pre_norm=transformer_pre_norm,
                 normformer=transformer_normformer,
                 post_norm=transformer_post_norm,
                 checkpoint_ff=transformer_checkpoint_ff,
+                alpha=self.alpha,
+                beta=self.beta,
             )
         else:
             self.transformer = nn.Identity()
@@ -391,16 +399,14 @@ class ViTEncoder(nn.Module):
                     or transformer_ff_linear_module_down
                     or linear_module
                 ),
-                pre_norm=transformer_pre_norm,
                 normformer=transformer_normformer,
-                post_norm=transformer_post_norm,
-                residual_path=transformer_initial_ff_residual_path,
                 checkpoint=transformer_checkpoint_ff,
+                beta=self.beta,
             )
         else:
             self.initial_ff = nn.Identity()
 
-        self.encoder = nn.Sequential(
+        self.preprocess = nn.Sequential(
             *[
                 batchnormxd(in_channels) if initial_batch_norm else nn.Identity(),
                 self.cnn,
@@ -410,19 +416,21 @@ class ViTEncoder(nn.Module):
                     f"N C {spatial_dim_names} -> N ({spatial_dim_names}) C"
                 ),
                 self.pooling_channels_padding,
-                self.initial_ff,
-                self.transformer,
+                nn.LayerNorm(),
             ]
         )
 
         self.reset_parameters()
 
     def forward(self, x):
-        return self.encoder(x)
+        x = self.preprocess(x)
+        x = x + self.initial_ff(x)
+        return self.transformer(x)
 
     def attention_logits(self, x):
-        x = self.encoder[:-1](x)
-        return self.encoder[-1].attention_logits(x)
+        x = self.preprocess(x)
+        x = x + self.initial_ff(x)
+        return self.transformer.attention_logits(x)
 
     def reset_parameters(self):
         for module in self.encoder:
@@ -456,7 +464,6 @@ class ViT(nn.Module):
         pooling_kernel_stride=2,
         pooling_padding=1,
         transformer_feedforward_first=True,
-        transformer_initial_ff_residual_path=True,
         transformer_initial_ff_linear_module_up=None,
         transformer_initial_ff_linear_module_down=None,
         transformer_initial_ff_dropout=None,
@@ -472,6 +479,7 @@ class ViT(nn.Module):
         transformer_heads=4,
         transformer_mlp_ratio=2,
         transformer_utility_tokens=0,
+        transformer_talking_heads=False,
         transformer_return_utility_tokens=False,
         transformer_activation: nn.Module = SquaredReLU,
         transformer_activation_kwargs: Optional[dict] = None,
@@ -488,6 +496,8 @@ class ViT(nn.Module):
         batch_norm_logits=True,
         logit_projection_layer=nn.Linear,
         linear_module=nn.Linear,
+        alpha=1.0,
+        beta=1.0,
     ):
 
         super().__init__()
@@ -508,6 +518,9 @@ class ViT(nn.Module):
             "SwiGLU": SwiGLU,
         }[transformer_activation]
 
+        self.alpha = alpha
+        self.beta = beta
+
         self.encoder = ViTEncoder(
             input_size=input_size,
             initial_batch_norm=initial_batch_norm,
@@ -527,7 +540,6 @@ class ViT(nn.Module):
             pooling_kernel_stride=pooling_kernel_stride,
             pooling_padding=pooling_padding,
             transformer_feedforward_first=transformer_feedforward_first,
-            transformer_initial_ff_residual_path=transformer_initial_ff_residual_path,
             transformer_initial_ff_linear_module_up=transformer_initial_ff_linear_module_up,
             transformer_initial_ff_linear_module_down=transformer_initial_ff_linear_module_down,
             transformer_initial_ff_dropout=transformer_initial_ff_dropout,
@@ -543,6 +555,7 @@ class ViT(nn.Module):
             transformer_heads=transformer_heads,
             transformer_mlp_ratio=transformer_mlp_ratio,
             transformer_utility_tokens=transformer_utility_tokens,
+            transformer_talking_heads=transformer_talking_heads,
             transformer_return_utility_tokens=transformer_return_utility_tokens,
             transformer_activation=transformer_activation,
             transformer_activation_kwargs=transformer_activation_kwargs,
@@ -556,6 +569,8 @@ class ViT(nn.Module):
             transformer_stochastic_depth=transformer_stochastic_depth,
             transformer_checkpoint_ff=transformer_checkpoint_ff,
             linear_module=linear_module,
+            alpha=alpha,
+            beta=beta,
         )
 
         self.pool = head(
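
`ViTEncoder.forward` now composes three explicit stages instead of one flat `nn.Sequential`: the convolutional preprocessing, a residual initial feedforward (the residual add moved from inside `FeedforwardBlock` to the call site), and the transformer. A stand-in sketch of the new control flow, with hypothetical stem modules:

import torch
import torch.nn as nn

preprocess = nn.Sequential(nn.Flatten(2), nn.Linear(16, 32))  # hypothetical CNN stem
initial_ff = nn.Linear(32, 32)                                # stand-in FeedforwardBlock
transformer = nn.Identity()                                   # stand-in TransformerEncoder

x = torch.randn(2, 3, 4, 4)  # [batch, channels, H, W]
x = preprocess(x)            # -> [2, 3, 16] -> tokens of width 32
x = x + initial_ff(x)        # the residual add now happens here, not inside the block
x = transformer(x)
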
broccoli_ml-13.0.1.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: broccoli-ml
-Version: 9.2.2
+Version: 13.0.1
 Summary: Some useful Pytorch models, circa 2025
 License: MIT
 Author: Nicholas Bailey
broccoli_ml-13.0.1.dist-info/RECORD ADDED
@@ -0,0 +1,13 @@
+broccoli/__init__.py,sha256=tmyspsVxqPZHRQCY_NRwpW4SMNBbtE8E_8z7l-SAzSo,127
+broccoli/activation.py,sha256=nrpTOrpg9k23_E4AJWy7VlXXAJCtCJCOR-TonEWJr04,3218
+broccoli/cnn.py,sha256=WjoPDSpe3ttwxCBNfCVRdaCHvbeZ7G-a5_i8fUsK_d8,4889
+broccoli/linear.py,sha256=W-3aNpBjd_0xRyzbCKkmg4H1qmslQOIQhB-WDDay2nM,13125
+broccoli/rope.py,sha256=GRqApBNmYCFaDak0WL1xE_BC5CTTYKQU_PBdeTcQcjc,12557
+broccoli/tensor.py,sha256=um8mrxkYbvNDo-QvHlmJm8Aw6qcngOlUZPoAk_PMReA,4480
+broccoli/transformer.py,sha256=3vAQQ75SAyr4-m3e7vSru8M-RpUy2Enp5cVUafaVYMU,28410
+broccoli/utils.py,sha256=oOWzn6dJ5nC_9r4zq0emmfmaYACJXJNFS48AOpW2jqc,358
+broccoli/vit.py,sha256=jd4e6MjL2JKB8ynSQssWRh6Hs36RuLj4uWyUNVhIMUY,22472
+broccoli_ml-13.0.1.dist-info/LICENSE,sha256=0BAzJE5BqQ7Iixp_AFdB2W1uO-HCRX-Qfun8PHt6yVM,1073
+broccoli_ml-13.0.1.dist-info/METADATA,sha256=gN9cKQDpwRr8JG_Ilj0ZknlPdxhq62I8xEBExVspKGw,1369
+broccoli_ml-13.0.1.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
+broccoli_ml-13.0.1.dist-info/RECORD,,
broccoli_ml-9.2.2.dist-info/RECORD REMOVED
@@ -1,13 +0,0 @@
-broccoli/__init__.py,sha256=tmyspsVxqPZHRQCY_NRwpW4SMNBbtE8E_8z7l-SAzSo,127
-broccoli/activation.py,sha256=-Jf30C6iGqWCorC9HEGn2oduWwjeaCAxGLUUYIy1zX8,3438
-broccoli/cnn.py,sha256=WjoPDSpe3ttwxCBNfCVRdaCHvbeZ7G-a5_i8fUsK_d8,4889
-broccoli/linear.py,sha256=IwvPAMbHqOYyz0g5WZyevPAhC1Pn0RTLniFM4E6lJoI,11511
-broccoli/rope.py,sha256=GRqApBNmYCFaDak0WL1xE_BC5CTTYKQU_PBdeTcQcjc,12557
-broccoli/tensor.py,sha256=um8mrxkYbvNDo-QvHlmJm8Aw6qcngOlUZPoAk_PMReA,4480
-broccoli/transformer.py,sha256=r-ggAeNDW5QpBi9As1U9sIfxITBOx0WHk_K4zWpyTM8,26233
-broccoli/utils.py,sha256=oOWzn6dJ5nC_9r4zq0emmfmaYACJXJNFS48AOpW2jqc,358
-broccoli/vit.py,sha256=sC6K3FK3a8ojOgvNWSWhuZHBtnFrrTQbsDdlagcKJH4,22224
-broccoli_ml-9.2.2.dist-info/LICENSE,sha256=0BAzJE5BqQ7Iixp_AFdB2W1uO-HCRX-Qfun8PHt6yVM,1073
-broccoli_ml-9.2.2.dist-info/METADATA,sha256=8ySQYntl9czgYyEQN5nyPS31tjwYC8M8Mx_iYhvtbzg,1368
-broccoli_ml-9.2.2.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
-broccoli_ml-9.2.2.dist-info/RECORD,,