x_transformers-2.1.4-py3-none-any.whl → x_transformers-2.1.5-py3-none-any.whl

This diff reflects the changes between publicly released versions of the package, as they appear in their public registry, and is provided for informational purposes only.
x_transformers/belief_state.py

@@ -35,11 +35,12 @@ class BeliefStateWrapper(Module):
     def __init__(
         self,
         forward_decoder: TransformerWrapper,
-        backward_decoder: TransformerWrapper
+        backward_decoder: TransformerWrapper,
+        train_frac_forward_backward_pairs: float = 1.
     ):
         super().__init__()
         assert forward_decoder.emb_dim == backward_decoder.emb_dim, 'forward and backwards model must have the same embedding dimension'
-        assert forward_decoder.num_tokens == backward_decoder.num_tokens, 'forward and backwards model must have the same embedding dimension'
+        assert forward_decoder.num_tokens == backward_decoder.num_tokens, 'forward and backwards model must have the same number of tokens'
 
         dim = forward_decoder.emb_dim
         num_tokens = forward_decoder.num_tokens
@@ -47,6 +48,7 @@ class BeliefStateWrapper(Module):
         # the suffix token
 
         self.suffix_token = nn.Parameter(torch.zeros(dim))
+        nn.init.normal_(self.suffix_token, std = 0.02)
 
         # the text prediction head, which predicts for the combinations of prefix and suffix the next and previous token for forwards and backward sequences
 
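For context, the added line replaces an all-zeros learned suffix token with a small-variance Gaussian initialization; std = 0.02 comes from the diff and is a common initialization scale in transformer codebases. A minimal standalone sketch of the same pattern, with an illustrative dimension:

    import torch
    from torch import nn

    dim = 512  # illustrative embedding dimension, not from the diff

    # create the learned token at zeros, then overwrite it in place with N(0, 0.02^2)
    suffix_token = nn.Parameter(torch.zeros(dim))
    nn.init.normal_(suffix_token, std = 0.02)

    print(suffix_token.detach().std())  # roughly 0.02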
@@ -61,6 +63,13 @@ class BeliefStateWrapper(Module):
         self.forward_decoder = forward_decoder
         self.backward_decoder = backward_decoder
 
+        # what fraction of forward backward pairs to train on
+        # for further memory efficiency
+
+        assert 0 < train_frac_forward_backward_pairs <= 1.
+        self.train_frac_fb_pairs = train_frac_forward_backward_pairs
+        self.needs_subsample_fb_pairs = train_frac_forward_backward_pairs < 1.
+
     def forward(
         self,
         seq
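Taken together, the constructor changes surface one new training knob. A hedged usage sketch, assuming the usual x-transformers TransformerWrapper/Decoder construction, that BeliefStateWrapper is importable from x_transformers.belief_state, and that the wrapper's forward returns a scalar training loss; all hyperparameters are illustrative:

    import torch
    from x_transformers import TransformerWrapper, Decoder
    from x_transformers.belief_state import BeliefStateWrapper

    # the two decoders must agree on emb_dim and num_tokens, per the asserts above
    forward_decoder = TransformerWrapper(
        num_tokens = 256,
        max_seq_len = 1024,
        attn_layers = Decoder(dim = 512, depth = 6, heads = 8)
    )

    backward_decoder = TransformerWrapper(
        num_tokens = 256,
        max_seq_len = 1024,
        attn_layers = Decoder(dim = 512, depth = 6, heads = 8)
    )

    wrapper = BeliefStateWrapper(
        forward_decoder = forward_decoder,
        backward_decoder = backward_decoder,
        train_frac_forward_backward_pairs = 0.5  # train on half the pairs for memory savings
    )

    seq = torch.randint(0, 256, (2, 512))
    loss = wrapper(seq)  # assumption: forward(seq) yields the training loss
    loss.backward()

Leaving train_frac_forward_backward_pairs at its default of 1. preserves the previous behavior, since needs_subsample_fb_pairs is only set when the fraction is strictly below 1.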
@@ -107,6 +116,17 @@ class BeliefStateWrapper(Module):
 
         fb_pairs = fb_pairs[valid_mask]
 
+        # maybe subsample fb pairs
+
+        if self.needs_subsample_fb_pairs:
+            num_pairs = fb_pairs.shape[0]
+
+            num_subsampled = max(int(num_pairs * self.train_frac_fb_pairs), 1)
+
+            rand_subsampled_indices = torch.randperm(num_pairs, device = device)[:num_subsampled]
+
+            fb_pairs = fb_pairs[rand_subsampled_indices]
+
         # get labels for both
 
         fi, bi = fb_pairs.unbind(dim = -1)
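The subsampling block is self-contained enough to exercise in isolation. A minimal sketch of the same randperm technique on dummy index pairs; the shapes and the fraction here are illustrative:

    import torch

    device = torch.device('cpu')
    train_frac_fb_pairs = 0.25

    # dummy (forward index, backward index) pairs, shape (num_pairs, 2)
    fb_pairs = torch.cartesian_prod(torch.arange(8), torch.arange(8))

    num_pairs = fb_pairs.shape[0]

    # keep at least one pair, mirroring the max(..., 1) guard in the diff
    num_subsampled = max(int(num_pairs * train_frac_fb_pairs), 1)

    # truncating a random permutation samples uniformly without replacement
    rand_subsampled_indices = torch.randperm(num_pairs, device = device)[:num_subsampled]

    fb_pairs = fb_pairs[rand_subsampled_indices]
    print(fb_pairs.shape)  # torch.Size([16, 2])

Sampling without replacement via a truncated permutation keeps the whole step on-device and loop-free, so the cost of the pairwise training pass shrinks roughly in proportion to the dropped fraction.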
x_transformers-2.1.5.dist-info/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: x-transformers
-Version: 2.1.4
+Version: 2.1.5
 Summary: X-Transformers
 Project-URL: Homepage, https://pypi.org/project/x-transformers/
 Project-URL: Repository, https://github.com/lucidrains/x-transformers
x_transformers-2.1.5.dist-info/RECORD

@@ -1,7 +1,7 @@
 x_transformers/__init__.py,sha256=l0dom8ZYkRzFvnDdgzDboXqrI1tKav3beVE7TN2nHko,844
 x_transformers/attend.py,sha256=-5BWWhFsp7tvZTdN91Ay5SqOjyj9uOs-122vFvoO6b4,17253
 x_transformers/autoregressive_wrapper.py,sha256=reLCno9Z9pchVU79tBF8OMo21LwSZ67KAeB83jqkyAc,10505
-x_transformers/belief_state.py,sha256=5E_08m6kvXROYgROazTJFxuUsDPNjIVM3AJxg3CJNmU,3966
+x_transformers/belief_state.py,sha256=GfYDeDqmhldozECgFsJ9zhd6O5NMvdYA5OwueVs8SB4,4742
 x_transformers/continuous.py,sha256=p0sCAiH1na236ygwgL1Yyhu36eZBf9cZvoW1JyP_fFE,7073
 x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
@@ -10,7 +10,7 @@ x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dC
 x_transformers/x_transformers.py,sha256=-80N4sqUr3sR51Ms4wCfc4jhxnPwf0ApNR4xfIsasfQ,110142
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=7S00kCuab4tWQa-vf-z-XfzADjVj48MoFIr7VSIvttg,8575
-x_transformers-2.1.4.dist-info/METADATA,sha256=-jme9jyXVeVlo1T7nPi2iGFfoxfjSwdj2D33_dr4yxQ,87570
-x_transformers-2.1.4.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-x_transformers-2.1.4.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-2.1.4.dist-info/RECORD,,
+x_transformers-2.1.5.dist-info/METADATA,sha256=-nQpm1eBGBXkLEXGiQK06NaIWa13CEnBujBNfTzvnJ8,87570
+x_transformers-2.1.5.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+x_transformers-2.1.5.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-2.1.5.dist-info/RECORD,,