robo-lib 0.0.9__tar.gz → 0.0.11__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {robo_lib-0.0.9 → robo_lib-0.0.11}/PKG-INFO +3 -3
- {robo_lib-0.0.9 → robo_lib-0.0.11}/README.md +1 -1
- {robo_lib-0.0.9 → robo_lib-0.0.11}/pyproject.toml +1 -1
- {robo_lib-0.0.9 → robo_lib-0.0.11}/robo_lib/components.py +9 -3
- {robo_lib-0.0.9 → robo_lib-0.0.11}/LICENSE +0 -0
- {robo_lib-0.0.9 → robo_lib-0.0.11}/robo_lib/__init__.py +0 -0
- {robo_lib-0.0.9 → robo_lib-0.0.11}/tests/__init__.py +0 -0
@@ -1,6 +1,6 @@
|
|
1
|
-
Metadata-Version: 2.
|
1
|
+
Metadata-Version: 2.4
|
2
2
|
Name: robo_lib
|
3
|
-
Version: 0.0.9
|
3
|
+
Version: 0.0.11
|
4
4
|
Summary: A package to create, configure, and train transformer models.
|
5
5
|
Project-URL: Homepage, https://github.com/hamburgerfish/robo_pack
|
6
6
|
Project-URL: Issues, https://github.com/hamburgerfish/robo_pack/issues
|
@@ -124,7 +124,7 @@ robo = rl.RoboConstructor(
|
|
124
124
|
enc_block_size=100
|
125
125
|
)
|
126
126
|
|
127
|
-
robo.train(
|
127
|
+
robo.train_robo(
|
128
128
|
max_iters=20000,
|
129
129
|
eval_interval=200,
|
130
130
|
batch_size=128,
|
@@ -615,6 +615,7 @@ class RoboConstructor(nn.Module):
|
|
615
615
|
enc_vocab_size:int=None,
|
616
616
|
enc_block_size:int=None,
|
617
617
|
enc_expansion_factor:int=4,
|
618
|
+
enc_positional_encoding:bool=True,
|
618
619
|
dropout:float=0.1,
|
619
620
|
device:str=None
|
620
621
|
) -> None:
|
@@ -635,6 +636,7 @@ class RoboConstructor(nn.Module):
|
|
635
636
|
self.dec_positional_embedding_table = nn.Embedding(dec_block_size, n_embed)
|
636
637
|
|
637
638
|
if enc_n_blocks != 0:
|
639
|
+
self.enc_positional_encoding = enc_positional_encoding
|
638
640
|
self.enc_n_blocks = enc_n_blocks
|
639
641
|
self.enc_n_head = enc_n_head
|
640
642
|
self.enc_expansion_factor = enc_expansion_factor
|
@@ -642,7 +644,8 @@ class RoboConstructor(nn.Module):
|
|
642
644
|
self.enc_block_size = enc_block_size
|
643
645
|
self.cross_attention = True
|
644
646
|
self.enc_token_embedding_table = nn.Embedding(enc_vocab_size, n_embed)
|
645
|
-
|
647
|
+
if enc_positional_encoding:
|
648
|
+
self.enc_positional_embedding_table = nn.Embedding(enc_block_size, n_embed)
|
646
649
|
self.encoder_blocks = MySequential(*[EncoderBlock(n_embed, enc_n_head, enc_expansion_factor, dropout=dropout) for _ in range(enc_n_blocks)])
|
647
650
|
else:
|
648
651
|
self.cross_attention = False
|
@@ -685,8 +688,11 @@ class RoboConstructor(nn.Module):
|
|
685
688
|
|
686
689
|
if self.cross_attention:
|
687
690
|
enc_tok_emb = self.enc_token_embedding_table(enc_in)
|
688
|
-
|
689
|
-
|
691
|
+
if self.enc_positional_encoding:
|
692
|
+
enc_pos_emb = self.enc_positional_embedding_table(torch.arange(enc_T, device=self.device))
|
693
|
+
enc_x = enc_tok_emb + enc_pos_emb
|
694
|
+
else:
|
695
|
+
enc_x = enc_tok_emb
|
690
696
|
|
691
697
|
enc_out, enc_mask = self.encoder_blocks(enc_x, enc_mask)
|
692
698
|
else:
|
File without changes
|
File without changes
|
File without changes
|