x-transformers 2.3.20__py3-none-any.whl → 2.3.21__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -141,6 +141,8 @@ class ContinuousTransformerWrapper(Module):
141
141
  sum_embeds = None,
142
142
  prepend_embeds = None,
143
143
  prepend_mask = None,
144
+ cache: LayerIntermediates | None = None,
145
+ input_not_include_cache = False,
144
146
  seq_start_pos = None,
145
147
  **kwargs
146
148
  ):
@@ -154,10 +156,17 @@ class ContinuousTransformerWrapper(Module):
154
156
 
155
157
  mask = einx.less('j, i -> i j', seq_arange, lens)
156
158
 
159
+ # take care of position embedding offsets in the presence of cache and sequence is less than cache length (not full sequence)
160
+
161
+ seq_pos_offset = 0
162
+
163
+ if exists(cache) and input_not_include_cache:
164
+ seq_pos_offset = cache.cache_length
165
+
157
166
  # project in + positional embedding
158
167
 
159
168
  x = self.project_in(x)
160
- x = x + self.pos_emb(x, pos = pos, seq_start_pos = seq_start_pos)
169
+ x = x + self.pos_emb(x, pos = pos, seq_start_pos = seq_start_pos, offset = seq_pos_offset)
161
170
 
162
171
  if exists(sum_embeds):
163
172
  x = x + sum_embeds
@@ -193,7 +202,7 @@ class ContinuousTransformerWrapper(Module):
193
202
 
194
203
  # attention layers
195
204
 
196
- x, intermediates = self.attn_layers(x, mask = mask, mems = mems, mem_masks = mem_masks, return_hiddens = True, **kwargs)
205
+ x, intermediates = self.attn_layers(x, mask = mask, mems = mems, mem_masks = mem_masks, cache = cache, input_not_include_cache = input_not_include_cache, seq_pos_offset = seq_pos_offset, return_hiddens = True, **kwargs)
197
206
 
198
207
  # splice out memory tokens
199
208
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: x-transformers
3
- Version: 2.3.20
3
+ Version: 2.3.21
4
4
  Summary: X-Transformers
5
5
  Project-URL: Homepage, https://pypi.org/project/x-transformers/
6
6
  Project-URL: Repository, https://github.com/lucidrains/x-transformers
@@ -2,7 +2,7 @@ x_transformers/__init__.py,sha256=h3I2ejobgEdy8H7NgV-rP8UaBCnd16-MysvDXH9GMEA,98
2
2
  x_transformers/attend.py,sha256=fXMuwHuBAFB4f4_U6j5_uVeK7N4cV0PDd6UTqtkjKKM,17333
3
3
  x_transformers/autoregressive_wrapper.py,sha256=LW1gr3cFONDEPA_HHhaTE7mk-JWbaINuB1fc_DfbCqw,10791
4
4
  x_transformers/belief_state_wrapper.py,sha256=YLUMk6t2MhFBEw5lHDDHJHcoCxTIkHvxTNY__GGZEKU,13374
5
- x_transformers/continuous.py,sha256=CHta8vizKl85n220fv5278fwjSU-vrN_FBy-m831_go,12551
5
+ x_transformers/continuous.py,sha256=hpb1sSbt3k2LNzzjrjSd8F5xOIbKj7IluV9MBEAFLkw,13031
6
6
  x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
7
7
  x_transformers/entropy_based_tokenizer.py,sha256=F2lO8-v3aLIcVDVNhu7RR-UtRdlmaaYJzBK9m7OnLE8,5018
8
8
  x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
@@ -11,7 +11,7 @@ x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dC
11
11
  x_transformers/x_transformers.py,sha256=l2p-r0iJNlYHUB3vM4lb6ptzNCx9HgA7UfgieEcQT6w,115521
12
12
  x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
13
13
  x_transformers/xval.py,sha256=AwwYUm8yDAtKQyKJDIhYMsiLTJ_skh3scUFMjp5sda8,8597
14
- x_transformers-2.3.20.dist-info/METADATA,sha256=ygWyfnlIh2Mw6bd12gJjjZJyM9vfnXmvvOLyrd2El2k,89897
15
- x_transformers-2.3.20.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
16
- x_transformers-2.3.20.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
17
- x_transformers-2.3.20.dist-info/RECORD,,
14
+ x_transformers-2.3.21.dist-info/METADATA,sha256=530_RGFGFlDyKIV6vMGqjGGw0f3gpArBbwNBHai_LQs,89897
15
+ x_transformers-2.3.21.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
16
+ x_transformers-2.3.21.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
17
+ x_transformers-2.3.21.dist-info/RECORD,,