broccoli-ml 5.0.0__py3-none-any.whl → 5.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
broccoli/transformer.py CHANGED
@@ -1,5 +1,5 @@
 import math
-from typing import Optional
+from typing import Optional, Tuple
 
 import torch
 import torch.nn as nn
@@ -102,6 +102,15 @@ class MHAttention(nn.Module):
 
         self.head_dim = self.embed_dim // self.n_heads
 
+        if self.scaling == "sqrtd":
+            self.scaling_factor = 1 / math.sqrt(self.head_dim)
+        elif self.scaling == "d":
+            # 8/d_model for backwards compatibility,
+            # per https://github.com/microsoft/mup
+            self.scaling_factor = 8 / self.head_dim
+        else:
+            raise ValueError('`scaling` argument to MHAttention must be "d" or "sqrtd"')
+
         self.q_proj = linear_module(self.embed_dim, self.embed_dim, bias=False)
         self.k_proj = linear_module(self.embed_dim, self.embed_dim, bias=False)
         self.v_proj = linear_module(self.embed_dim, self.embed_dim, bias=False)
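The hunk above moves the attention scaling factor out of `forward` and precomputes it once in `__init__`: `"sqrtd"` selects the standard 1/√d_head factor, while `"d"` keeps the muP-style 8/d_head factor for backwards compatibility. A minimal sketch of the two options, using a hypothetical `head_dim` of 64 (the value is illustrative, not taken from the package):

```python
import math

head_dim = 64  # hypothetical head dimension (embed_dim // n_heads)

sqrtd_factor = 1 / math.sqrt(head_dim)  # standard attention scaling: 1/sqrt(d)
mup_factor = 8 / head_dim               # muP-style scaling: 8/d

# The two factors coincide at head_dim == 64 (both 0.125) and diverge as the
# head dimension grows, which is the point of offering both options.
print(sqrtd_factor, mup_factor)
```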
@@ -122,6 +131,8 @@ class MHAttention(nn.Module):
         self.source_size = source_size
         self.bos_tokens = bos_tokens
 
+        self.reset_parameters()
+
     @property
     def _kv_distance(self) -> float:
         """
@@ -141,7 +152,71 @@ class MHAttention(nn.Module):
 
         return 1 - similarity
 
-    def forward(self, q, k, v):
+    def add_axial_rope(
+        self, q: torch.Tensor, k: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Apply Axial RoPE to all tokens except BOS tokens
+        """
+
+        if len(self.source_size) == 1:
+            spatial_dimension_names = "D1"
+            spatial_dimension_values = {"D1": self.source_size[0]}
+        elif len(self.source_size) == 2:
+            spatial_dimension_names = "D1 D2"
+            spatial_dimension_values = {
+                "D1": self.source_size[0],
+                "D2": self.source_size[1],
+            }
+        elif len(self.source_size) == 3:
+            spatial_dimension_names = "D1 D2 D3"
+            spatial_dimension_values = {
+                "D1": self.source_size[0],
+                "D2": self.source_size[1],
+                "D3": self.source_size[2],
+            }
+        else:
+            raise NotImplementedError(
+                "`source_size` must be a tuple of 1, 2 or 3 integers"
+            )
+
+        q_bos, q_img = q[:, : self.bos_tokens, :], q[:, self.bos_tokens :, :]
+        k_bos, k_img = k[:, : self.bos_tokens, :], k[:, self.bos_tokens :, :]
+
+        q_img = rearrange(
+            q_img,
+            f"b ({spatial_dimension_names}) d -> b {spatial_dimension_names} d",
+            **spatial_dimension_values,
+        )
+        k_img = rearrange(
+            k_img,
+            f"b ({spatial_dimension_names}) d -> b {spatial_dimension_names} d",
+            **spatial_dimension_values,
+        )
+
+        freqs = self.rotary_embedding.get_axial_freqs(*self.source_size)
+
+        q_img = apply_rotary_emb(freqs, q_img)
+        k_img = apply_rotary_emb(freqs, k_img)
+
+        q_img = rearrange(
+            q_img,
+            f"b {spatial_dimension_names} d -> b ({spatial_dimension_names}) d",
+        )
+        k_img = rearrange(
+            k_img,
+            f"b {spatial_dimension_names} d -> b ({spatial_dimension_names}) d",
+        )
+
+        # Re-combine the BOS tokens and the RoPE-enhanced image tokens
+        q = torch.cat([q_bos, q_img], dim=1)
+        k = torch.cat([k_bos, k_img], dim=1)
+
+        return q, k
+
+    def project_qkv(
+        self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         query_batch_size, query_tokens, query_features = q.size()
         key_batch_size, key_tokens, key_features = k.size()
 
@@ -154,74 +229,18 @@ class MHAttention(nn.Module):
 
         if self.causal:
             assert query_tokens == key_tokens
-            assert query_tokens == self.sequence_length
+            assert query_tokens == self.seq_len
 
-        # Project q, k and v
-        q = self.q_proj(q)
-        k = self.k_proj(k)
-        v = self.v_proj(v)
+        q, k, v = self.q_proj(q), self.k_proj(k), self.v_proj(v)
 
-        # Rearrange dimensions and add RoPE if needed
         if self.rotary_embedding is not None:
+            q, k = self.add_axial_rope(q, k)
 
-            if len(self.source_size) == 1:
-                spatial_dimension_names = "D1"
-                spatial_dimension_values = {"D1": self.source_size[0]}
-            elif len(self.source_size) == 2:
-                spatial_dimension_names = "D1 D2"
-                spatial_dimension_values = {
-                    "D1": self.source_size[0],
-                    "D2": self.source_size[1],
-                }
-            elif len(self.source_size) == 3:
-                spatial_dimension_names = "D1 D2 D3"
-                spatial_dimension_values = {
-                    "D1": self.source_size[0],
-                    "D2": self.source_size[1],
-                    "D3": self.source_size[2],
-                }
-            else:
-                raise NotImplementedError(
-                    "`source_size` must be a tuple of 1, 2 or 3 integers"
-                )
-
-            q_bos, q_img = q[:, : self.bos_tokens, :], q[:, self.bos_tokens :, :]
-            k_bos, k_img = k[:, : self.bos_tokens, :], k[:, self.bos_tokens :, :]
-
-            q_img = rearrange(
-                q_img,
-                f"b ({spatial_dimension_names}) d -> b {spatial_dimension_names} d",
-                **spatial_dimension_values,
-            )
-            k_img = rearrange(
-                k_img,
-                f"b ({spatial_dimension_names}) d -> b {spatial_dimension_names} d",
-                **spatial_dimension_values,
-            )
-            freqs = self.rotary_embedding.get_axial_freqs(*self.source_size)
-            q_img = apply_rotary_emb(freqs, q_img)
-            k_img = apply_rotary_emb(freqs, k_img)
-
-            q_img = rearrange(
-                q_img,
-                f"b {spatial_dimension_names} d -> b ({spatial_dimension_names}) d",
-            )
-            k_img = rearrange(
-                k_img,
-                f"b {spatial_dimension_names} d -> b ({spatial_dimension_names}) d",
-            )
+        return q, k, v
 
-            # Re-combine the BOS tokens and the RoPE-enhanced image tokens
-            q = torch.cat([q_bos, q_img], dim=1)
-            k = torch.cat([k_bos, k_img], dim=1)
+    def forward(self, q, k, v):
 
-        if self.scaling == "sqrtd":
-            scaling_factor = 1 / math.sqrt(self.head_dim)
-        elif self.scaling == "d":
-            # for backwards compatibility, per https://github.com/microsoft/mup
-            scaling_factor = 8 / self.head_dim
-        else:
-            raise ValueError('`scaling` argument to MHAttention must be "d" or "sqrtd"')
+        q, k, v = self.project_qkv(q, k, v)
 
         if FLASH_ATTN:
             # Divide Q/K/V into heads
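The hunks above pull the projection and rotary-embedding logic out of `forward` into two helpers: `project_qkv` now performs the Q/K/V projections (plus the shape and causality checks), and `add_axial_rope` rotates only the image tokens, leaving the leading BOS tokens untouched. A minimal sketch of that split-and-reshape flow, with made-up sizes and the rotary call itself elided (it needs the package's `rotary_embedding` object, which is outside this diff):

```python
import torch
from einops import rearrange

bos_tokens, source_size, dim = 2, (8, 8), 32  # hypothetical values
q = torch.randn(4, bos_tokens + source_size[0] * source_size[1], dim)

# Split off the BOS tokens so they are never rotated
q_bos, q_img = q[:, :bos_tokens, :], q[:, bos_tokens:, :]

# Reshape the flat image sequence back to its 2D grid (the real method builds
# the einops pattern string for 1, 2 or 3 spatial dimensions)
q_img = rearrange(q_img, "b (D1 D2) d -> b D1 D2 d", D1=source_size[0], D2=source_size[1])

# ... axial RoPE would be applied to q_img here ...

# Flatten back to a sequence and re-attach the untouched BOS tokens
q_img = rearrange(q_img, "b D1 D2 d -> b (D1 D2) d")
q = torch.cat([q_bos, q_img], dim=1)
print(q.shape)  # torch.Size([4, 66, 32])
```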
@@ -234,7 +253,7 @@ class MHAttention(nn.Module):
                 k,
                 v,
                 dropout_p=self.dropout.p if self.training else 0.0,
-                softmax_scale=scaling_factor,
+                softmax_scale=self.scaling_factor,
                 causal=self.causal,
             )
 
@@ -249,7 +268,7 @@ class MHAttention(nn.Module):
 
             qk_scores = q @ k.transpose(-1, -2)
 
-            qk_scores *= scaling_factor
+            qk_scores *= self.scaling_factor
 
             # Apply mask if causal (must come before softmax)
             if self.causal:
@@ -265,6 +284,34 @@ class MHAttention(nn.Module):
 
         return self.out_proj(output_without_heads)
 
+    def attention_scores(self, q, k, v):
+
+        q, k, v = self.project_qkv(q, k, v)
+
+        # Divide Q/K/V into heads
+        q = rearrange(q, "b t (h d) -> b h t d", h=self.n_heads)
+        k = rearrange(k, "b t (h d) -> b h t d", h=self.n_heads)
+        v = rearrange(v, "b t (h d) -> b h t d", h=self.n_heads)
+
+        qk_scores = q @ k.transpose(-1, -2)
+
+        qk_scores *= self.scaling_factor
+
+        # Apply mask if causal (must come before softmax)
+        if self.causal:
+            qk_scores.masked_fill_(self.mask, float("-inf"))
+
+        qk_scores = F.softmax(qk_scores, dim=-1)
+
+        return qk_scores  # (batch, head, seq_len, seq_len)
+
+    def reset_parameters(self):
+        # Default nn.Linear init is kaiming_uniform, which is fine
+        self.q_proj.reset_parameters()
+        self.k_proj.reset_parameters()
+        self.v_proj.reset_parameters()
+        self.out_proj.reset_parameters()
+
 
 class FeedforwardBlock(nn.Module):
     """
@@ -327,6 +374,8 @@ class FeedforwardBlock(nn.Module):
             ]
         )
 
+        self.reset_parameters()
+
     def forward(self, x):
 
         if self.checkpoint:
@@ -341,6 +390,15 @@ class FeedforwardBlock(nn.Module):
         else:
            return processed
 
+    def reset_parameters(self):
+        if self.post_norm:
+            self.layernorm.reset_parameters()
+
+        # Iterate over the sequential block to reset parameters
+        for module in self.process:
+            if hasattr(module, "reset_parameters"):
+                module.reset_parameters()
+
 
 class TransformerBlock(nn.Module):
     """
@@ -445,6 +503,8 @@ class TransformerBlock(nn.Module):
             checkpoint=checkpoint_ff,
         )
 
+        self.reset_parameters()
+
     @property
     def _kv_distance(self) -> float:
         return self.attn._kv_distance
@@ -469,11 +529,29 @@ class TransformerBlock(nn.Module):
 
         return x
 
+    def attention_scores(self, x):
+        """
+        Give back the attention scores used in this layer.
+        """
+        if self.pre_norm:
+            x = self.layer_norm_1(x)
+            return self.attn(x, x, x)
+        else:
+            return self.attn(x, x, x)
+
+    def reset_parameters(self):
+        self.layer_norm_1.reset_parameters()
+        self.layer_norm_2.reset_parameters()
+        self.layer_norm_3.reset_parameters()
+
+        self.attn.reset_parameters()
+        self.ff.reset_parameters()
+
 
 class TransformerEncoder(nn.Module):
     """
     This assumes we already get a sequence of embeddings (e.g. word or image
-    patch embeddings). It uses learned positional embeddings.
+    patch embeddings).
     """
 
     def __init__(
@@ -584,11 +662,13 @@ class TransformerEncoder(nn.Module):
             ]
         )
 
+        self.reset_parameters()
+
     @property
     def _kv_distances(self) -> float:
         return ",".join([str(block._kv_distance) for block in self.blocks])
 
-    def forward(self, x):
+    def preprocess(self, x):
         if self._bos_tokens:
             x = torch.cat([self._bos_embedding.expand(x.size(0), -1, -1), x], dim=1)
         else:
@@ -603,6 +683,10 @@ class TransformerEncoder(nn.Module):
             ) # to shape (1, seq_len) to broadcast over batch
         )
 
+    def forward(self, x):
+
+        x = self.preprocess(x)
+
         for block in self.blocks:
             x = block(x)
 
@@ -610,3 +694,27 @@ class TransformerEncoder(nn.Module):
             return x[:, self._bos_tokens :, :]
         else:
             return x
+
+    def attention_scores(self, x):
+
+        x = self.preprocess(x)
+
+        layer_scores = []
+
+        for block in self.blocks:
+            # Get attention scores with shape (batch, 1, head, seq_len, seq_len)
+            layer_attention_scores = block.attention_scores(x).unsqueeze(1)
+            layer_scores.append(layer_attention_scores)
+            x = block(x)
+
+        return torch.cat(layer_scores, dim=1)  # (batch, layer, head, seq_len, seq_len)
+
+    def reset_parameters(self):
+        if self._bos_embedding is not None:
+            nn.init.normal_(self._bos_embedding, mean=0.0, std=1.0)
+
+        if self.absolute_position_embedding is not None:
+            self.absolute_position_embedding.reset_parameters()
+
+        for block in self.blocks:
+            block.reset_parameters()
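The new `TransformerEncoder.attention_scores` runs the same preprocessing as `forward`, collects each block's attention map before pushing the activations through that block, and concatenates the per-layer maps into a single tensor (per the in-diff comment, `(batch, layer, head, seq_len, seq_len)`). A self-contained sketch of that stacking pattern, using a deliberately simplified stand-in block rather than the package's real `MHAttention`/`TransformerBlock`:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class MiniSelfAttention(nn.Module):
    """Illustrative stand-in for a block exposing attention_scores."""

    def __init__(self, dim: int, heads: int):
        super().__init__()
        self.heads, self.head_dim = heads, dim // heads
        self.qkv = nn.Linear(dim, 3 * dim, bias=False)

    def attention_scores(self, x: torch.Tensor) -> torch.Tensor:
        b, t, _ = x.shape
        q, k, _ = self.qkv(x).chunk(3, dim=-1)
        q = q.view(b, t, self.heads, self.head_dim).transpose(1, 2)
        k = k.view(b, t, self.heads, self.head_dim).transpose(1, 2)
        return F.softmax(q @ k.transpose(-1, -2) / self.head_dim**0.5, dim=-1)


blocks = nn.ModuleList([MiniSelfAttention(32, 4) for _ in range(3)])
x = torch.randn(2, 10, 32)

# Same stacking pattern as TransformerEncoder.attention_scores: one softmaxed
# map per block, stacked along a new "layer" dimension (the real method also
# keeps propagating x through each block between measurements).
layer_scores = [block.attention_scores(x).unsqueeze(1) for block in blocks]
scores = torch.cat(layer_scores, dim=1)
print(scores.shape)  # torch.Size([2, 3, 4, 10, 10]) = (batch, layer, head, seq, seq)
```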
broccoli/vit.py CHANGED
@@ -9,7 +9,9 @@ from .utils import PadTensor
 from einops import einsum
 from einops.layers.torch import Rearrange
 
+import torch
 import torch.nn as nn
+import torch.nn.functional as F
 
 
 class GetCLSToken(nn.Module):
@@ -31,10 +33,18 @@ class SequencePool(nn.Module):
             ]
         )
 
+        self.reset_parameters()
+
     def forward(self, x):
         weights = self.attention(x)
         return einsum(weights, x, "batch seq, batch seq d_model -> batch d_model")
 
+    def reset_parameters(self):
+        # Iterate over modules in the sequential block
+        for module in self.attention:
+            if hasattr(module, "reset_parameters"):
+                module.reset_parameters()
+
 
 class ClassificationHead(nn.Module):
     """
@@ -71,9 +81,16 @@ class ClassificationHead(nn.Module):
             ]
         )
 
+        self.reset_parameters()
+
     def forward(self, x):
         return self.classification_process(x)
 
+    def reset_parameters(self):
+        for module in self.classification_process:
+            if hasattr(module, "reset_parameters"):
+                module.reset_parameters()
+
 
 class SequencePoolClassificationHead(ClassificationHead):
     """
@@ -106,6 +123,8 @@ class SequencePoolClassificationHead(ClassificationHead):
             ]
         )
 
+        self.reset_parameters()
+
 
 class ViTEncoder(nn.Module):
     """
@@ -376,9 +395,20 @@ class ViTEncoder(nn.Module):
             ]
         )
 
+        self.reset_parameters()
+
     def forward(self, x):
         return self.encoder(x)
 
+    def attention_scores(self, x):
+        x = self.encoder[:-1](x)
+        return self.encoder[-1].attention_scores(x)
+
+    def reset_parameters(self):
+        for module in self.encoder:
+            if hasattr(module, "reset_parameters"):
+                module.reset_parameters()
+
 
 class ViT(nn.Module):
     """
@@ -507,9 +537,26 @@ class ViT(nn.Module):
             batch_norm_logits=batch_norm_logits,
         )
 
+        self.reset_parameters()
+
     @property
     def sequence_length(self):
         return self.encoder.sequence_length
 
     def forward(self, x):
         return self.pool(self.encoder(x))
+
+    def attention_scores(self, x):
+        return self.encoder.attention_scores(x)
+
+    def head_to_bos_token_attention(self, x):
+        all_attention = self.attention_scores(x)
+        batch_averages = torch.mean(all_attention, dim=0, keepdim=False)
+        sequence_averages = torch.mean(batch_averages, dim=-1, keepdim=False)
+        n_bos_tokens = self.encoder.encoder._bos_tokens
+        just_bos = sequence_averages[:, :, :n_bos_tokens]
+        return F.softmax(just_bos, dim=-1)  # (layer, head, bos_token)
+
+    def reset_parameters(self):
+        self.encoder.reset_parameters()
+        self.pool.reset_parameters()
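The new `ViT.head_to_bos_token_attention` post-processes the stacked attention maps: it averages over the batch and over the last attention dimension, keeps only the leading BOS-token positions, and softmaxes over them to give one weight per (layer, head, BOS token). A small sketch of just that reduction on a random tensor, with made-up sizes (the real input comes from `attention_scores`):

```python
import torch
import torch.nn.functional as F

# Made-up sizes; the real tensor comes from ViT.attention_scores(x)
batch, layers, heads, bos, seq = 3, 4, 8, 2, 66
all_attention = torch.rand(batch, layers, heads, seq, seq)

batch_averages = all_attention.mean(dim=0)       # (layer, head, seq, seq)
sequence_averages = batch_averages.mean(dim=-1)  # (layer, head, seq)
just_bos = sequence_averages[:, :, :bos]         # keep the BOS positions only
weights = F.softmax(just_bos, dim=-1)            # (layer, head, bos_token)
print(weights.shape)  # torch.Size([4, 8, 2])
```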
broccoli_ml-5.0.0.dist-info/METADATA → broccoli_ml-5.1.0.dist-info/METADATA RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: broccoli-ml
-Version: 5.0.0
+Version: 5.1.0
 Summary: Some useful Pytorch models, circa 2025
 License: MIT
 Author: Nicholas Bailey
broccoli_ml-5.0.0.dist-info/RECORD → broccoli_ml-5.1.0.dist-info/RECORD RENAMED
@@ -4,10 +4,10 @@ broccoli/cnn.py,sha256=WjoPDSpe3ttwxCBNfCVRdaCHvbeZ7G-a5_i8fUsK_d8,4889
 broccoli/linear.py,sha256=Y7s-DzcwsOipRboNHc4HTScw4mJRalNoVFsNcxOB6a4,4872
 broccoli/rope.py,sha256=GRqApBNmYCFaDak0WL1xE_BC5CTTYKQU_PBdeTcQcjc,12557
 broccoli/tensor.py,sha256=um8mrxkYbvNDo-QvHlmJm8Aw6qcngOlUZPoAk_PMReA,4480
-broccoli/transformer.py,sha256=eSBRF-HYJ-BxfisJCueUYCIYtHXgj1ewG5RxEfcmu-E,20128
+broccoli/transformer.py,sha256=x3Mo6_1x6fGG6lPDPx9srxn6UdwKEpvjFAO8zoMwAMI,23052
 broccoli/utils.py,sha256=oOWzn6dJ5nC_9r4zq0emmfmaYACJXJNFS48AOpW2jqc,358
-broccoli/vit.py,sha256=BrNLOx4_gTY6xTwAn8xT-HOgUnSFtU6_m1CpJXuQiKY,18907
-broccoli_ml-5.0.0.dist-info/LICENSE,sha256=0BAzJE5BqQ7Iixp_AFdB2W1uO-HCRX-Qfun8PHt6yVM,1073
-broccoli_ml-5.0.0.dist-info/METADATA,sha256=NMkRLZfqhMZdBIn4BMHjF_-jyv3yYUlOrKeHRTt2rnE,1368
-broccoli_ml-5.0.0.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
-broccoli_ml-5.0.0.dist-info/RECORD,,
+broccoli/vit.py,sha256=tUYQyoDsBc5ZR_M5_J0huj0T3OAy-vn1f19hCGVDCrM,20425
+broccoli_ml-5.1.0.dist-info/LICENSE,sha256=0BAzJE5BqQ7Iixp_AFdB2W1uO-HCRX-Qfun8PHt6yVM,1073
+broccoli_ml-5.1.0.dist-info/METADATA,sha256=3986lqn1iuWJ53O8ckM9LVU3tTjr32i19SeIXauWDXw,1368
+broccoli_ml-5.1.0.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
+broccoli_ml-5.1.0.dist-info/RECORD,,