x-transformers 1.26.4__py3-none-any.whl → 1.26.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- x_transformers/autoregressive_wrapper.py +4 -0
- x_transformers/x_transformers.py +4 -1
- {x_transformers-1.26.4.dist-info → x_transformers-1.26.5.dist-info}/METADATA +1 -1
- {x_transformers-1.26.4.dist-info → x_transformers-1.26.5.dist-info}/RECORD +7 -7
- {x_transformers-1.26.4.dist-info → x_transformers-1.26.5.dist-info}/LICENSE +0 -0
- {x_transformers-1.26.4.dist-info → x_transformers-1.26.5.dist-info}/WHEEL +0 -0
- {x_transformers-1.26.4.dist-info → x_transformers-1.26.5.dist-info}/top_level.txt +0 -0
x_transformers/autoregressive_wrapper.py CHANGED
@@ -188,6 +188,10 @@ class AutoregressiveWrapper(Module):
         for _ in range(seq_len):
 
             if restrict_to_max_seq_len:
+                max_len_exceeded = out.shape[-1] > max_seq_len
+
+                assert not (cache_kv and max_len_exceeded and self.net.can_cache_kv_outside_max_seq_len), 'the network cannot use cached key values when decoding outside the max sequence length. most likely because you are using absolute positional embeeding. you can switch to rotary embeddings to resolve this issue'
+
                 x = out[:, -max_seq_len:]
 
                 if exists(cache):
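The new assertion only matters when decoding runs past max_seq_len with the kv cache enabled. Below is a minimal sketch of a configuration that reaches that path and satisfies the check; the hyperparameters are illustrative, and it assumes the standard x-transformers constructors (TransformerWrapper, Decoder, AutoregressiveWrapper) and the cache_kv keyword of generate.

```python
import torch
from x_transformers import TransformerWrapper, Decoder, AutoregressiveWrapper

# rotary embeddings and no absolute positional embedding, so cached kv
# decoding is allowed to continue past max_seq_len
model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 1024,
    use_abs_pos_emb = False,       # skip the absolute positional embedding
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        rotary_pos_emb = True      # positions come from rotary embeddings instead
    )
)

wrapper = AutoregressiveWrapper(model)

prompt = torch.randint(0, 256, (1, 1024))

# generating 128 tokens from a full-length prompt exceeds max_seq_len, which is
# exactly the case the new assertion guards; this configuration passes it
generated = wrapper.generate(prompt, 128, cache_kv = True)
```

With the default use_abs_pos_emb = True and no rotary or relative positional embedding, the same call would instead trip the new assertion.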
x_transformers/x_transformers.py CHANGED
@@ -1495,7 +1495,9 @@ class TransformerWrapper(nn.Module):
         self.l2norm_embed = l2norm_embed
         self.token_emb = TokenEmbedding(emb_dim, num_tokens, l2norm_embed = l2norm_embed)
 
-
+        no_abs_pos_emb = max_seq_len == 0 or not (use_abs_pos_emb and not attn_layers.has_pos_emb)
+
+        if no_abs_pos_emb:
             self.pos_emb = always(0)
         elif scaled_sinu_pos_emb:
             self.pos_emb = ScaledSinusoidalEmbedding(emb_dim)
@@ -1536,6 +1538,7 @@ class TransformerWrapper(nn.Module):
         # whether can do cached kv decoding
 
         self.can_cache_kv = self.num_memory_tokens == 0
+        self.can_cache_kv_outside_max_seq_len = not no_abs_pos_emb
 
     def init_(self):
         if self.l2norm_embed:
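The two hunks above name the "no absolute positional embedding" condition and store a related capability flag that the new assertion in autoregressive_wrapper.py consults at generation time. Here is a standalone sketch of the condition itself, with the constructor state flattened into plain booleans; the function name is illustrative and not part of the library.

```python
# mirrors: no_abs_pos_emb = max_seq_len == 0 or not (use_abs_pos_emb and not attn_layers.has_pos_emb)
def uses_absolute_pos_emb(max_seq_len: int, use_abs_pos_emb: bool, attn_layers_has_pos_emb: bool) -> bool:
    no_abs_pos_emb = max_seq_len == 0 or not (use_abs_pos_emb and not attn_layers_has_pos_emb)
    return not no_abs_pos_emb

# absolute positions are kept only with use_abs_pos_emb = True, a nonzero
# max_seq_len, and attention layers that do not supply their own positions
assert uses_absolute_pos_emb(1024, True, False)
assert not uses_absolute_pos_emb(1024, False, False)  # user disabled abs pos emb
assert not uses_absolute_pos_emb(1024, True, True)    # rotary/relative emb already present
assert not uses_absolute_pos_emb(0, True, False)      # max_seq_len == 0
```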
{x_transformers-1.26.4.dist-info → x_transformers-1.26.5.dist-info}/RECORD CHANGED
@@ -1,13 +1,13 @@
 x_transformers/__init__.py,sha256=pXc_U4M3ONUQcpNgZySDIlCF1rp7u4FFmcOYjc4WuXw,629
 x_transformers/attend.py,sha256=MFl_FbgPsm9mziZPTi_s8QbxASETwbGeciMH8sUIwT8,10188
-x_transformers/autoregressive_wrapper.py,sha256=
+x_transformers/autoregressive_wrapper.py,sha256=47sc7HAMNBJUGZRtZX-cO-yML0YFcw4PF6E-7pp1E0A,9614
 x_transformers/continuous.py,sha256=ixfgi2_zpGN03SX_STXFkNYEOAkgwVIxuS53QgDCx-g,6026
 x_transformers/nonautoregressive_wrapper.py,sha256=AQLE4rA_Kh8VNoe9OzpwyeWson34sRkhks4dn4seNjI,10414
-x_transformers/x_transformers.py,sha256=
+x_transformers/x_transformers.py,sha256=8n8R_huY0KwKDGTUlLLhleAqNR5M1YI_95KRmhrP_Eg,61740
 x_transformers/xl_autoregressive_wrapper.py,sha256=DCx4n0_c1tFai4nOqaWVnqx2p9eutsZsDMiMP1ckxNU,4117
 x_transformers/xval.py,sha256=lS9W_E_RskPQAqVZkPiUzbByoW1Ajsw_phsikA3JDAg,8139
-x_transformers-1.26.
-x_transformers-1.26.
-x_transformers-1.26.
-x_transformers-1.26.
-x_transformers-1.26.
+x_transformers-1.26.5.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.26.5.dist-info/METADATA,sha256=GcEy7CtmuqOpAapRxh7Et5kfPOBiV2EIa6GjN2U-eFM,661
+x_transformers-1.26.5.dist-info/WHEEL,sha256=oiQVh_5PnQM0E3gPdiz09WCNmwiHDMaGer_elqB3coM,92
+x_transformers-1.26.5.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.26.5.dist-info/RECORD,,
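Each RECORD line follows the wheel convention of path, sha256= plus the urlsafe-base64 digest with padding stripped, and the file size in bytes. A small sketch of how such an entry could be recomputed to spot-check the new hashes; record_entry is an illustrative helper, not part of any packaging tool.

```python
import base64
import hashlib
from pathlib import Path

def record_entry(path: str) -> str:
    # hash the file exactly as stored in the wheel, then encode the raw digest
    # with urlsafe base64 and strip the '=' padding, per the wheel RECORD format
    data = Path(path).read_bytes()
    digest = base64.urlsafe_b64encode(hashlib.sha256(data).digest()).rstrip(b"=").decode()
    return f"{path},sha256={digest},{len(data)}"

# run inside an unpacked x_transformers-1.26.5 wheel, this should reproduce
# the "+ x_transformers/x_transformers.py,sha256=8n8R_...,61740" line above
print(record_entry("x_transformers/x_transformers.py"))
```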
{x_transformers-1.26.4.dist-info → x_transformers-1.26.5.dist-info}/LICENSE: File without changes
{x_transformers-1.26.4.dist-info → x_transformers-1.26.5.dist-info}/WHEEL: File without changes
{x_transformers-1.26.4.dist-info → x_transformers-1.26.5.dist-info}/top_level.txt: File without changes