x-transformers 2.3.20__py3-none-any.whl → 2.3.21__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- x_transformers/continuous.py +11 -2
- {x_transformers-2.3.20.dist-info → x_transformers-2.3.21.dist-info}/METADATA +1 -1
- {x_transformers-2.3.20.dist-info → x_transformers-2.3.21.dist-info}/RECORD +5 -5
- {x_transformers-2.3.20.dist-info → x_transformers-2.3.21.dist-info}/WHEEL +0 -0
- {x_transformers-2.3.20.dist-info → x_transformers-2.3.21.dist-info}/licenses/LICENSE +0 -0
x_transformers/continuous.py
CHANGED
@@ -141,6 +141,8 @@ class ContinuousTransformerWrapper(Module):
|
|
141
141
|
sum_embeds = None,
|
142
142
|
prepend_embeds = None,
|
143
143
|
prepend_mask = None,
|
144
|
+
cache: LayerIntermediates | None = None,
|
145
|
+
input_not_include_cache = False,
|
144
146
|
seq_start_pos = None,
|
145
147
|
**kwargs
|
146
148
|
):
|
@@ -154,10 +156,17 @@ class ContinuousTransformerWrapper(Module):
|
|
154
156
|
|
155
157
|
mask = einx.less('j, i -> i j', seq_arange, lens)
|
156
158
|
|
159
|
+
# take care of position embedding offsets in the presence of cache and sequence is less than cache length (not full sequence)
|
160
|
+
|
161
|
+
seq_pos_offset = 0
|
162
|
+
|
163
|
+
if exists(cache) and input_not_include_cache:
|
164
|
+
seq_pos_offset = cache.cache_length
|
165
|
+
|
157
166
|
# project in + positional embedding
|
158
167
|
|
159
168
|
x = self.project_in(x)
|
160
|
-
x = x + self.pos_emb(x, pos = pos, seq_start_pos = seq_start_pos)
|
169
|
+
x = x + self.pos_emb(x, pos = pos, seq_start_pos = seq_start_pos, offset = seq_pos_offset)
|
161
170
|
|
162
171
|
if exists(sum_embeds):
|
163
172
|
x = x + sum_embeds
|
@@ -193,7 +202,7 @@ class ContinuousTransformerWrapper(Module):
|
|
193
202
|
|
194
203
|
# attention layers
|
195
204
|
|
196
|
-
x, intermediates = self.attn_layers(x, mask = mask, mems = mems, mem_masks = mem_masks, return_hiddens = True, **kwargs)
|
205
|
+
x, intermediates = self.attn_layers(x, mask = mask, mems = mems, mem_masks = mem_masks, cache = cache, input_not_include_cache = input_not_include_cache, seq_pos_offset = seq_pos_offset, return_hiddens = True, **kwargs)
|
197
206
|
|
198
207
|
# splice out memory tokens
|
199
208
|
|
@@ -2,7 +2,7 @@ x_transformers/__init__.py,sha256=h3I2ejobgEdy8H7NgV-rP8UaBCnd16-MysvDXH9GMEA,98
|
|
2
2
|
x_transformers/attend.py,sha256=fXMuwHuBAFB4f4_U6j5_uVeK7N4cV0PDd6UTqtkjKKM,17333
|
3
3
|
x_transformers/autoregressive_wrapper.py,sha256=LW1gr3cFONDEPA_HHhaTE7mk-JWbaINuB1fc_DfbCqw,10791
|
4
4
|
x_transformers/belief_state_wrapper.py,sha256=YLUMk6t2MhFBEw5lHDDHJHcoCxTIkHvxTNY__GGZEKU,13374
|
5
|
-
x_transformers/continuous.py,sha256=
|
5
|
+
x_transformers/continuous.py,sha256=hpb1sSbt3k2LNzzjrjSd8F5xOIbKj7IluV9MBEAFLkw,13031
|
6
6
|
x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
|
7
7
|
x_transformers/entropy_based_tokenizer.py,sha256=F2lO8-v3aLIcVDVNhu7RR-UtRdlmaaYJzBK9m7OnLE8,5018
|
8
8
|
x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
|
@@ -11,7 +11,7 @@ x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dC
|
|
11
11
|
x_transformers/x_transformers.py,sha256=l2p-r0iJNlYHUB3vM4lb6ptzNCx9HgA7UfgieEcQT6w,115521
|
12
12
|
x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
|
13
13
|
x_transformers/xval.py,sha256=AwwYUm8yDAtKQyKJDIhYMsiLTJ_skh3scUFMjp5sda8,8597
|
14
|
-
x_transformers-2.3.
|
15
|
-
x_transformers-2.3.
|
16
|
-
x_transformers-2.3.
|
17
|
-
x_transformers-2.3.
|
14
|
+
x_transformers-2.3.21.dist-info/METADATA,sha256=530_RGFGFlDyKIV6vMGqjGGw0f3gpArBbwNBHai_LQs,89897
|
15
|
+
x_transformers-2.3.21.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
16
|
+
x_transformers-2.3.21.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
|
17
|
+
x_transformers-2.3.21.dist-info/RECORD,,
|
File without changes
|
File without changes
|