broccoli-ml 6.0.0__py3-none-any.whl → 7.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
broccoli/transformer.py CHANGED
@@ -13,6 +13,7 @@ from .rope import RotaryEmbedding, apply_rotary_emb
13
13
  try:
14
14
  from flash_attn import flash_attn_func
15
15
 
16
+ print("Using flash-attn.")
16
17
  FLASH_ATTN = True
17
18
  except ImportError:
18
19
  pass
@@ -590,7 +591,13 @@ class TransformerEncoder(nn.Module):
590
591
  if relative_position_embedding and (source_size is None):
591
592
  raise ValueError(
592
593
  "`source_size` for TransformerEncoder cannot be None if"
593
- " `position_embedding_type` is relative"
594
+ " `relative_position_embedding` is True"
595
+ )
596
+
597
+ if absolute_position_embedding and (seq_len is None):
598
+ raise ValueError(
599
+ "`seq_len` for TransformerEncoder cannot be None if"
600
+ " `absolute_position_embedding` is True"
594
601
  )
595
602
 
596
603
  super().__init__()
@@ -605,9 +612,12 @@ class TransformerEncoder(nn.Module):
605
612
  torch.empty(self._utility_tokens, d_model)
606
613
  )
607
614
  nn.init.normal_(self._utility_token_embedding, mean=0.0, std=1.0)
608
- self.full_sequence_length = self.seq_len + self._utility_tokens
609
615
  else:
610
616
  self._utility_token_embedding = None
617
+
618
+ if self._utility_tokens and (self.seq_len is not None):
619
+ self.full_sequence_length = self.seq_len + self._utility_tokens
620
+ else:
611
621
  self.full_sequence_length = self.seq_len
612
622
 
613
623
  self.d_model = d_model
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: broccoli-ml
3
- Version: 6.0.0
3
+ Version: 7.0.0
4
4
  Summary: Some useful Pytorch models, circa 2025
5
5
  License: MIT
6
6
  Author: Nicholas Bailey
@@ -4,10 +4,10 @@ broccoli/cnn.py,sha256=WjoPDSpe3ttwxCBNfCVRdaCHvbeZ7G-a5_i8fUsK_d8,4889
4
4
  broccoli/linear.py,sha256=Y7s-DzcwsOipRboNHc4HTScw4mJRalNoVFsNcxOB6a4,4872
5
5
  broccoli/rope.py,sha256=GRqApBNmYCFaDak0WL1xE_BC5CTTYKQU_PBdeTcQcjc,12557
6
6
  broccoli/tensor.py,sha256=um8mrxkYbvNDo-QvHlmJm8Aw6qcngOlUZPoAk_PMReA,4480
7
- broccoli/transformer.py,sha256=MxIdzoxoWx_IWcq86vDZJIV4tk-dMNivhopZu8zJk90,23293
7
+ broccoli/transformer.py,sha256=-NRO8mzvOkQPCuTrF6OTDQ8sUIQ3_j4HxP-NfTUxu10,23636
8
8
  broccoli/utils.py,sha256=oOWzn6dJ5nC_9r4zq0emmfmaYACJXJNFS48AOpW2jqc,358
9
9
  broccoli/vit.py,sha256=9oyh76ulmX5lDPMCDicQhhqm8RYCvJIgAJkDbYRVdi4,20873
10
- broccoli_ml-6.0.0.dist-info/LICENSE,sha256=0BAzJE5BqQ7Iixp_AFdB2W1uO-HCRX-Qfun8PHt6yVM,1073
11
- broccoli_ml-6.0.0.dist-info/METADATA,sha256=Sv8nRPb7oCAeoMe3AAHIYDewAETvb0ZDxN8IKFniVHk,1368
12
- broccoli_ml-6.0.0.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
13
- broccoli_ml-6.0.0.dist-info/RECORD,,
10
+ broccoli_ml-7.0.0.dist-info/LICENSE,sha256=0BAzJE5BqQ7Iixp_AFdB2W1uO-HCRX-Qfun8PHt6yVM,1073
11
+ broccoli_ml-7.0.0.dist-info/METADATA,sha256=OdAph_5ItouyYHeZwVgbrTiyexljFb9uKNLNWlnR4wM,1368
12
+ broccoli_ml-7.0.0.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
13
+ broccoli_ml-7.0.0.dist-info/RECORD,,