x-transformers 2.1.12__py3-none-any.whl → 2.1.15__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
x_transformers/belief_state_wrapper.py

@@ -1,7 +1,7 @@
 
 # Belief State Transformer
 
-# https://arxiv.org/abs/2410.23506
+# Hu et al. https://arxiv.org/abs/2410.23506
 # https://www.youtube.com/watch?v=aqhbRtB2Fyg
 
 from __future__ import annotations
@@ -107,7 +107,7 @@ class BeliefStateWrapper(Module):
 
     @torch.no_grad()
     @eval_decorator
-    def generate_with_suffix_token_only(
+    def generate_with_suffix_cond(
         self,
         prompts,
         seq_len,
@@ -132,8 +132,11 @@ class BeliefStateWrapper(Module):
 
         # get the encoded suffix token once
 
-        if exists(suffix) and suffix.ndim == 1:
-            suffix = repeat(suffix, 'n -> b n', b = batch)
+        if exists(suffix):
+            if suffix.ndim == 1:
+                suffix = repeat(suffix, 'n -> b n', b = batch)
+
+            suffix = suffix.flip(1) # reverse autoregressive
 
         suffix_sos_tokens = rearrange(self.suffix_token, 'd -> 1 1 d')
 
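The new branch above broadcasts a 1-d suffix across the batch and then reverses it along the sequence dimension, consistent with the `# reverse autoregressive` comment: the Belief State Transformer's suffix model reads the suffix back-to-front. A minimal sketch of the tensor semantics, assuming a 1-d suffix of token ids (the `batch` value and ids here are illustrative):

import torch
from einops import repeat  # same helper the wrapper uses

batch = 2
suffix = torch.tensor([5, 6, 7])                    # suffix token ids, left to right

if suffix.ndim == 1:
    suffix = repeat(suffix, 'n -> b n', b = batch)  # broadcast one suffix across the batch

suffix = suffix.flip(1)                             # reversed for the backward (suffix) model
# suffix is now tensor([[7, 6, 5],
#                       [7, 6, 5]])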
@@ -145,6 +148,10 @@ class BeliefStateWrapper(Module):
             return_embeddings = True
         )
 
+        # pick out the last embedding for fill in the middle
+
+        suffix_embed = suffix_embed[:, -1:]
+
         # sampling up to seq_len
 
         for _ in range(seq_len):
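Because the suffix is now encoded in reverse, the last output position of `suffix_embed` has attended to the entire suffix, so the wrapper keeps only that one embedding (`[:, -1:]`, which preserves the sequence axis) as the fill-in-the-middle conditioning. A hedged usage sketch of the renamed method; the `forward_decoder` argument name and the model hyperparameters are assumptions for illustration, not taken from this diff:

import torch
from x_transformers import TransformerWrapper, Decoder
from x_transformers.belief_state_wrapper import BeliefStateWrapper

model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 1024,
    attn_layers = Decoder(dim = 512, depth = 6, heads = 8)
)

# constructor argument name is an assumption for illustration
belief = BeliefStateWrapper(forward_decoder = model)

prompt = torch.randint(0, 256, (1, 16))    # prefix to continue from
suffix = torch.randint(0, 256, (32,))      # 1-d suffix; broadcast over the batch internally

# renamed in this release from generate_with_suffix_token_only
out = belief.generate_with_suffix_cond(prompt, 64, suffix = suffix)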
x_transformers-2.1.12.dist-info/METADATA → x_transformers-2.1.15.dist-info/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: x-transformers
-Version: 2.1.12
+Version: 2.1.15
 Summary: X-Transformers
 Project-URL: Homepage, https://pypi.org/project/x-transformers/
 Project-URL: Repository, https://github.com/lucidrains/x-transformers
x_transformers-2.1.12.dist-info/RECORD → x_transformers-2.1.15.dist-info/RECORD

@@ -1,7 +1,7 @@
 x_transformers/__init__.py,sha256=NDoiBivau559WQ0FvXG4ssU3Il9aoHmTIUFN_1juz0s,911
 x_transformers/attend.py,sha256=-5BWWhFsp7tvZTdN91Ay5SqOjyj9uOs-122vFvoO6b4,17253
 x_transformers/autoregressive_wrapper.py,sha256=reLCno9Z9pchVU79tBF8OMo21LwSZ67KAeB83jqkyAc,10505
-x_transformers/belief_state_wrapper.py,sha256=aMmekjRNHem-4MKXTK8z_u0497EThUhvKLISwaKbqQw,8665
+x_transformers/belief_state_wrapper.py,sha256=LclzwJ4FjfRh4b68Y1IJsWsmo2ymffbutnOqfTg-LdM,8854
 x_transformers/continuous.py,sha256=p0sCAiH1na236ygwgL1Yyhu36eZBf9cZvoW1JyP_fFE,7073
 x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
@@ -10,7 +10,7 @@ x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dC
 x_transformers/x_transformers.py,sha256=fqgtIs6__JpLWMnJa8AY5OW3AJ2GR1B5p-9TsWdiOIU,110425
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=7S00kCuab4tWQa-vf-z-XfzADjVj48MoFIr7VSIvttg,8575
-x_transformers-2.1.12.dist-info/METADATA,sha256=JXQYWgfNcv43jVmFY_FAiIhD5EYvX88BA4zJXeRMxa0,87571
-x_transformers-2.1.12.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-x_transformers-2.1.12.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-2.1.12.dist-info/RECORD,,
+x_transformers-2.1.15.dist-info/METADATA,sha256=zV-K3eS2O8ld1ph7JqTW31EcqFasoEh625OPXHz2N78,87571
+x_transformers-2.1.15.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+x_transformers-2.1.15.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-2.1.15.dist-info/RECORD,,