x-transformers 2.4.8__py3-none-any.whl → 2.4.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -7,7 +7,7 @@ from random import randrange, uniform
7
7
 
8
8
  import torch
9
9
  from torch import nn, cat, tensor, randperm
10
- from torch.nn import LSTM, Module
10
+ from torch.nn import LSTM, GRU, Module
11
11
 
12
12
  from x_transformers.x_transformers import (
13
13
  TransformerWrapper,
@@ -61,7 +61,9 @@ class SyntheticDataGenerator(Module):
61
61
  dim,
62
62
  num_tokens,
63
63
  max_seq_len = 512,
64
- hidden_size = None
64
+ hidden_size = None,
65
+ use_gru = False,
66
+ network_klass = None
65
67
  ):
66
68
  super().__init__()
67
69
 
@@ -70,7 +72,11 @@ class SyntheticDataGenerator(Module):
70
72
  self.embed = nn.Embedding(num_tokens, dim)
71
73
 
72
74
  hidden_size = default(hidden_size, dim)
73
- self.lstm = LSTM(dim, hidden_size, batch_first = True)
75
+
76
+ default_network_klass = partial(LSTM if not use_gru else GRU, batch_first = True)
77
+ network_klass = default(network_klass, default_network_klass)
78
+
79
+ self.net = network_klass(dim, hidden_size)
74
80
 
75
81
  self.to_logits = nn.Linear(dim, num_tokens, bias = False)
76
82
 
@@ -128,7 +134,7 @@ class SyntheticDataGenerator(Module):
128
134
 
129
135
  tokens = self.embed(input)
130
136
 
131
- embed, hidden = self.lstm(tokens, hiddens)
137
+ embed, hidden = self.net(tokens, hiddens)
132
138
 
133
139
  logits = self.to_logits(embed)
134
140
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: x-transformers
3
- Version: 2.4.8
3
+ Version: 2.4.9
4
4
  Summary: X-Transformers
5
5
  Project-URL: Homepage, https://pypi.org/project/x-transformers/
6
6
  Project-URL: Repository, https://github.com/lucidrains/x-transformers
@@ -8,11 +8,11 @@ x_transformers/entropy_based_tokenizer.py,sha256=F2lO8-v3aLIcVDVNhu7RR-UtRdlmaaY
8
8
  x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
9
9
  x_transformers/neo_mlp.py,sha256=XCNnnop9WLarcxap1kGuYc1x8GHvwkZiDRnXOxSl3Po,3452
10
10
  x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
11
- x_transformers/up_wrapper.py,sha256=J4cIqqEDpm9pkG-QsdlsNzUHbovOxB1iR1zYGqnDxAM,6864
11
+ x_transformers/up_wrapper.py,sha256=YC2LN14_7Xx9Wtiek2rtEJ_qHqdfSmKlh3d7Cgxwd80,7073
12
12
  x_transformers/x_transformers.py,sha256=9Fi0HvzpeIJqM6HlAd2M6JqsfjhTN1zEH9iFIimyjS4,117608
13
13
  x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
14
14
  x_transformers/xval.py,sha256=AwwYUm8yDAtKQyKJDIhYMsiLTJ_skh3scUFMjp5sda8,8597
15
- x_transformers-2.4.8.dist-info/METADATA,sha256=tqhX2F0TzwpBfPtla3n2_6SATVJGtQKfoBbfHk9Fvzw,90223
16
- x_transformers-2.4.8.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
17
- x_transformers-2.4.8.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
18
- x_transformers-2.4.8.dist-info/RECORD,,
15
+ x_transformers-2.4.9.dist-info/METADATA,sha256=yRYvqg0EZr7jvv-sRfBC2iU2tCu_pUo37KtLxYu44hg,90223
16
+ x_transformers-2.4.9.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
17
+ x_transformers-2.4.9.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
18
+ x_transformers-2.4.9.dist-info/RECORD,,