x-transformers 2.1.7__py3-none-any.whl → 2.1.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
--- a/x_transformers/belief_state_wrapper.py
+++ b/x_transformers/belief_state_wrapper.py
@@ -1,8 +1,11 @@
+
 # Belief State Transformer
 
 # https://arxiv.org/abs/2410.23506
 # https://www.youtube.com/watch?v=aqhbRtB2Fyg
 
+from __future__ import annotations
+
 import torch
 from torch.autograd import Function
 from torch.nn import Module, ModuleList
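The `from __future__ import annotations` line added here is what makes the next hunk legal on older interpreters: the new `text_head: Module | None = None` parameter uses PEP 604 union syntax, which Python only evaluates natively from 3.10 onward. A minimal sketch of the failure mode the future import avoids (`build_head` is a hypothetical helper, purely for illustration):

```python
from __future__ import annotations  # PEP 563: annotations stay strings, never evaluated

from torch.nn import Module

# Without the future import, this def raises
# `TypeError: unsupported operand type(s) for |` at import time on
# Python 3.9 and older, because `Module | None` is evaluated eagerly
# when the function is defined. With it, the annotation is stored as a
# plain string, so any supported interpreter accepts the file.
def build_head(text_head: Module | None = None) -> Module | None:
    return text_head
```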
@@ -37,6 +40,7 @@ class BeliefStateWrapper(Module):
         forward_decoder: TransformerWrapper,
         backward_decoder: TransformerWrapper,
         train_frac_forward_backward_pairs: float = 1.,
+        text_head: Module | None = None,
         backward_ar_loss_weight: float = 1. # can weigh the training of the backwards decoder differently, perhaps fwd/bwd have a shared backbone etc etc
     ):
         super().__init__()
@@ -53,11 +57,14 @@ class BeliefStateWrapper(Module):
 
         # the text prediction head, which predicts for the combinations of prefix and suffix the next and previous token for forwards and backward sequences
 
-        self.text_head = nn.Sequential(
-            nn.Linear(dim * 2, dim),
-            nn.LeakyReLU(),
-            nn.Linear(dim, num_tokens * 2),
-        )
+        if not exists(text_head):
+            text_head = nn.Sequential(
+                nn.Linear(dim * 2, dim),
+                nn.LeakyReLU(),
+                nn.Linear(dim, num_tokens * 2),
+            )
+
+        self.text_head = text_head
 
         # the two decoders, one which is causal forward, the other causal backwards
 
--- a/x_transformers-2.1.7.dist-info/METADATA
+++ b/x_transformers-2.1.9.dist-info/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: x-transformers
-Version: 2.1.7
+Version: 2.1.9
 Summary: X-Transformers
 Project-URL: Homepage, https://pypi.org/project/x-transformers/
 Project-URL: Repository, https://github.com/lucidrains/x-transformers
--- a/x_transformers-2.1.7.dist-info/RECORD
+++ b/x_transformers-2.1.9.dist-info/RECORD
@@ -1,7 +1,7 @@
 x_transformers/__init__.py,sha256=NDoiBivau559WQ0FvXG4ssU3Il9aoHmTIUFN_1juz0s,911
 x_transformers/attend.py,sha256=-5BWWhFsp7tvZTdN91Ay5SqOjyj9uOs-122vFvoO6b4,17253
 x_transformers/autoregressive_wrapper.py,sha256=reLCno9Z9pchVU79tBF8OMo21LwSZ67KAeB83jqkyAc,10505
-x_transformers/belief_state_wrapper.py,sha256=WoagJe_cHguWAWzhEAmeDD_TzUsQcMQkgP2Mt88wrAg,5827
+x_transformers/belief_state_wrapper.py,sha256=bA-H9BqfyVRI5Q7GIcGbzdLjmon0CKKGHb08-BnpJOs,5990
 x_transformers/continuous.py,sha256=p0sCAiH1na236ygwgL1Yyhu36eZBf9cZvoW1JyP_fFE,7073
 x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
@@ -10,7 +10,7 @@ x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dC
 x_transformers/x_transformers.py,sha256=-80N4sqUr3sR51Ms4wCfc4jhxnPwf0ApNR4xfIsasfQ,110142
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=7S00kCuab4tWQa-vf-z-XfzADjVj48MoFIr7VSIvttg,8575
-x_transformers-2.1.7.dist-info/METADATA,sha256=D7enhdRYdhTddt4kGSS-rSsvtHgtFaLYQH-cVuYxUVM,87570
-x_transformers-2.1.7.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-x_transformers-2.1.7.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-2.1.7.dist-info/RECORD,,
+x_transformers-2.1.9.dist-info/METADATA,sha256=YXRYQPw90873on2JyVr56N8Tz4ua2ibPawTUDvVE35g,87570
+x_transformers-2.1.9.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+x_transformers-2.1.9.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-2.1.9.dist-info/RECORD,,
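For reference, each RECORD line has the form `path,sha256=<digest>,<size>`, where per the wheel spec (PEP 427) the digest is the urlsafe-base64 SHA-256 of the file with `=` padding stripped; the `belief_state_wrapper.py` entry growing from 5827 to 5990 bytes (163 bytes) matches the small amount of code added above. A short sketch, with a hypothetical `record_line` helper meant to be run from the root of an unpacked wheel, of how such an entry can be reproduced for verification:

```python
import base64
import hashlib
from pathlib import Path

def record_line(path: str) -> str:
    """Recompute a wheel RECORD entry: path,sha256=<urlsafe-b64, unpadded>,<size>."""
    data = Path(path).read_bytes()
    digest = base64.urlsafe_b64encode(hashlib.sha256(data).digest()).rstrip(b"=").decode()
    return f"{path},sha256={digest},{len(data)}"

# Run from the root of the unpacked 2.1.9 wheel, this should print the
# entry from the RECORD diff above:
#   x_transformers/belief_state_wrapper.py,sha256=bA-H9BqfyVRI5Q7GIcGbzdLjmon0CKKGHb08-BnpJOs,5990
print(record_line("x_transformers/belief_state_wrapper.py"))
```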