robo-lib 0.0.10__py3-none-any.whl → 0.0.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
robo_lib/components.py CHANGED
@@ -615,6 +615,7 @@ class RoboConstructor(nn.Module):
615
615
  enc_vocab_size:int=None,
616
616
  enc_block_size:int=None,
617
617
  enc_expansion_factor:int=4,
618
+ enc_positional_encoding:bool=True,
618
619
  dropout:float=0.1,
619
620
  device:str=None
620
621
  ) -> None:
@@ -635,6 +636,7 @@ class RoboConstructor(nn.Module):
635
636
  self.dec_positional_embedding_table = nn.Embedding(dec_block_size, n_embed)
636
637
 
637
638
  if enc_n_blocks != 0:
639
+ self.enc_positional_encoding = enc_positional_encoding
638
640
  self.enc_n_blocks = enc_n_blocks
639
641
  self.enc_n_head = enc_n_head
640
642
  self.enc_expansion_factor = enc_expansion_factor
@@ -642,7 +644,8 @@ class RoboConstructor(nn.Module):
642
644
  self.enc_block_size = enc_block_size
643
645
  self.cross_attention = True
644
646
  self.enc_token_embedding_table = nn.Embedding(enc_vocab_size, n_embed)
645
- self.enc_positional_embedding_table = nn.Embedding(enc_block_size, n_embed)
647
+ if enc_positional_encoding:
648
+ self.enc_positional_embedding_table = nn.Embedding(enc_block_size, n_embed)
646
649
  self.encoder_blocks = MySequential(*[EncoderBlock(n_embed, enc_n_head, enc_expansion_factor, dropout=dropout) for _ in range(enc_n_blocks)])
647
650
  else:
648
651
  self.cross_attention = False
@@ -685,8 +688,11 @@ class RoboConstructor(nn.Module):
685
688
 
686
689
  if self.cross_attention:
687
690
  enc_tok_emb = self.enc_token_embedding_table(enc_in)
688
- enc_pos_emb = self.enc_positional_embedding_table(torch.arange(enc_T, device=self.device))
689
- enc_x = enc_tok_emb + enc_pos_emb
691
+ if self.enc_positional_encoding:
692
+ enc_pos_emb = self.enc_positional_embedding_table(torch.arange(enc_T, device=self.device))
693
+ enc_x = enc_tok_emb + enc_pos_emb
694
+ else:
695
+ enc_x = enc_tok_emb
690
696
 
691
697
  enc_out, enc_mask = self.encoder_blocks(enc_x, enc_mask)
692
698
  else:
@@ -1,6 +1,6 @@
1
- Metadata-Version: 2.3
1
+ Metadata-Version: 2.4
2
2
  Name: robo_lib
3
- Version: 0.0.10
3
+ Version: 0.0.11
4
4
  Summary: A package to create, configure, and train transformer models.
5
5
  Project-URL: Homepage, https://github.com/hamburgerfish/robo_pack
6
6
  Project-URL: Issues, https://github.com/hamburgerfish/robo_pack/issues
@@ -0,0 +1,6 @@
1
+ robo_lib/__init__.py,sha256=iVOAsANj0lScVW9KKMxCULYmpp0cv4sv1k3sHjBSlE0,1012
2
+ robo_lib/components.py,sha256=L_GUEHdKC_-Xn56ObQ9-DH8T1ywaz0M8jlWv227gZBs,42591
3
+ robo_lib-0.0.11.dist-info/METADATA,sha256=ePF06l2FXzo0qjK8v9Vob4WnOQ61KVd0mUqd7JVG7j4,9634
4
+ robo_lib-0.0.11.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
5
+ robo_lib-0.0.11.dist-info/licenses/LICENSE,sha256=4XzkkpFqPzH0GH3zxOqRTqc7xUKSEe7dWPOuJYW95ac,1089
6
+ robo_lib-0.0.11.dist-info/RECORD,,
@@ -1,4 +1,4 @@
1
1
  Wheel-Version: 1.0
2
- Generator: hatchling 1.25.0
2
+ Generator: hatchling 1.27.0
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
@@ -1,6 +0,0 @@
1
- robo_lib/__init__.py,sha256=iVOAsANj0lScVW9KKMxCULYmpp0cv4sv1k3sHjBSlE0,1012
2
- robo_lib/components.py,sha256=OjusjkSlMlAsTEq1kSqixKXG9sBw8Re8hsXTEy_bJ48,42315
3
- robo_lib-0.0.10.dist-info/METADATA,sha256=a30lSFG-Eo9UGFQErA64MTbeVqCeD8BwViXMmB2OPX4,9634
4
- robo_lib-0.0.10.dist-info/WHEEL,sha256=1yFddiXMmvYK7QYTqtRNtX66WJ0Mz8PYEiEUoOUUxRY,87
5
- robo_lib-0.0.10.dist-info/licenses/LICENSE,sha256=4XzkkpFqPzH0GH3zxOqRTqc7xUKSEe7dWPOuJYW95ac,1089
6
- robo_lib-0.0.10.dist-info/RECORD,,