x-transformers 1.27.3__py3-none-any.whl → 1.27.4__py3-none-any.whl
This diff shows the changes between two package versions that have been publicly released to a supported registry. It is provided for informational purposes only and reflects the packages as they appear in their respective public registries.
- x_transformers/x_transformers.py +23 -5
- {x_transformers-1.27.3.dist-info → x_transformers-1.27.4.dist-info}/METADATA +1 -1
- {x_transformers-1.27.3.dist-info → x_transformers-1.27.4.dist-info}/RECORD +6 -6
- {x_transformers-1.27.3.dist-info → x_transformers-1.27.4.dist-info}/LICENSE +0 -0
- {x_transformers-1.27.3.dist-info → x_transformers-1.27.4.dist-info}/WHEEL +0 -0
- {x_transformers-1.27.3.dist-info → x_transformers-1.27.4.dist-info}/top_level.txt +0 -0
x_transformers/x_transformers.py
CHANGED
```diff
@@ -813,6 +813,7 @@ class Attention(nn.Module):
         rotary_pos_emb = None,
         prev_attn = None,
         mem = None,
+        mem_mask = None,
         return_intermediates = False,
         cache: Optional[Intermediates] = None,
     ):
```
```diff
@@ -879,8 +880,15 @@ class Attention(nn.Module):
         if not exists(input_mask) and not has_context:
             input_mask = mask
 
-        if exists(input_mask) and exists(mem):
-            input_mask = pad_at_dim(input_mask, (mem.shape[-2], 0), dim = -1, value = True)
+        if (exists(input_mask) or exists(mem_mask)) and exists(mem):
+            seq_len, mem_len = n, mem.shape[-2]
+
+            if not exists(mem_mask):
+                input_mask = pad_at_dim(input_mask, (mem_len, 0), dim = -1, value = True)
+            elif not exists(input_mask):
+                input_mask = pad_at_dim(mem_mask, (0, seq_len), dim = -1, value = True)
+            else:
+                input_mask = torch.cat((mem_mask, input_mask), dim = -1)
 
         if self.num_mem_kv > 0:
             mem_k, mem_v = map(lambda t: repeat(t, 'h n d -> b h n d', b = b), (self.mem_k, self.mem_v))
```
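The rewritten branch above merges an optional memory mask with the regular input mask before memories are attended to. A minimal standalone sketch of that merge, using `F.pad` in place of the library's internal `pad_at_dim` helper (the `merge_mem_mask` function name and the tensor shapes are illustrative assumptions, not part of the package):

```python
import torch
import torch.nn.functional as F

def merge_mem_mask(input_mask, mem_mask, seq_len, mem_len):
    # pad whichever mask is missing with True (i.e. fully visible),
    # otherwise concatenate the memory mask in front of the token mask
    if mem_mask is None:
        return F.pad(input_mask, (mem_len, 0), value = True)
    if input_mask is None:
        return F.pad(mem_mask, (0, seq_len), value = True)
    return torch.cat((mem_mask, input_mask), dim = -1)

# batch of 2, 4 current tokens, 3 memory slots (illustrative)
input_mask = torch.tensor([[True, True, True, False]] * 2)
mem_mask   = torch.tensor([[True, False, True]] * 2)

merged = merge_mem_mask(input_mask, mem_mask, seq_len = 4, mem_len = 3)
print(merged.shape)  # torch.Size([2, 7]) -> memory slots first, then current tokens
```

Previously the prepended memory positions were always fully visible; with this release individual memory slots can be excluded from attention.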
```diff
@@ -1221,6 +1229,7 @@ class AttentionLayers(nn.Module):
         attn_mask = None,
         self_attn_kv_mask = None,
         mems = None,
+        mem_masks = None,
         seq_start_pos: Optional[Tensor] = None,
         cache: Optional[LayerIntermediates] = None,
         cache_age = 1,
```
```diff
@@ -1239,6 +1248,7 @@ class AttentionLayers(nn.Module):
         prev_cross_attn = None
 
         mems = mems.copy() if exists(mems) else [None] * self.num_attn_layers
+        mem_masks = mem_masks.copy() if exists(mem_masks) else [None] * self.num_attn_layers
 
         # handle left padded sequences
 
```
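`mem_masks` follows the same per-layer convention as `mems`: one entry per self-attention layer, with `None` entries allowed, consumed one at a time as the layer loop advances. A small sketch of how the two lists are expected to line up (depth, batch size, memory length, and model dimension are illustrative assumptions):

```python
import torch

depth, batch, mem_len, dim = 6, 2, 64, 512

# one memory tensor and one boolean validity mask per self-attention layer
mems      = [torch.randn(batch, mem_len, dim) for _ in range(depth)]
mem_masks = [torch.ones(batch, mem_len, dtype = torch.bool) for _ in range(depth)]

# mirrors the layer loop: copy the lists, then pop one entry per self-attention block
mems, mem_masks = mems.copy(), mem_masks.copy()
layer_mem, layer_mem_mask = mems.pop(0), mem_masks.pop(0)
```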
```diff
@@ -1255,7 +1265,12 @@ class AttentionLayers(nn.Module):
 
         if not exists(rotary_pos_emb) and exists(self.rotary_pos_emb):
             max_rotary_emb_length = max(list(map(lambda m: (m.shape[1] if exists(m) else 0) + x.shape[1], mems)))
-            rotary_pos_emb = self.rotary_pos_emb(max_rotary_emb_length)
+
+            maybe_mem = mems[0] # todo - handle edge case where different layers get different memory lengths. don't think this will ever come up but who knows
+            mem_len = maybe_mem.shape[1] if exists(maybe_mem) else 0
+
+            pos = torch.arange(x.shape[1] + mem_len, device = x.device) - mem_len
+            rotary_pos_emb = self.rotary_pos_emb(pos)
 
         # assume cached key / values
 
```
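The rotary-embedding change offsets positions by the memory length, so memory tokens sit at negative positions while the current tokens keep positions `0 .. seq_len - 1`. A standalone sketch of just the position arithmetic from the added lines (the concrete lengths are illustrative):

```python
import torch

seq_len, mem_len = 4, 3

# memories occupy negative positions; the current tokens keep positions 0..seq_len-1
pos = torch.arange(seq_len + mem_len) - mem_len
print(pos)  # tensor([-3, -2, -1,  0,  1,  2,  3])
```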
```diff
@@ -1296,7 +1311,9 @@ class AttentionLayers(nn.Module):
             if layer_type == 'a':
                 if return_hiddens:
                     hiddens.append(x)
+
                 layer_mem = mems.pop(0) if mems else None
+                layer_mem_mask = mem_masks.pop(0) if mem_masks else None
 
             if layer_type == 'c':
                 if self.training and self.cross_attn_tokens_dropout > 0.:
```
```diff
@@ -1313,7 +1330,7 @@ class AttentionLayers(nn.Module):
                 x = pre_norm(x)
 
             if layer_type == 'a':
-                out, inter = block(x, mask = mask, context_mask = self_attn_kv_mask, attn_mask = attn_mask, rel_pos = self.rel_pos, rotary_pos_emb = rotary_pos_emb, prev_attn = prev_attn, cache = next(iter_attn_cache, None), mem = layer_mem, return_intermediates = True)
+                out, inter = block(x, mask = mask, context_mask = self_attn_kv_mask, attn_mask = attn_mask, rel_pos = self.rel_pos, rotary_pos_emb = rotary_pos_emb, prev_attn = prev_attn, cache = next(iter_attn_cache, None), mem = layer_mem, mem_mask = layer_mem_mask, return_intermediates = True)
             elif layer_type == 'c':
                 out, inter = block(x, context = context, mask = mask, context_mask = context_mask, prev_attn = prev_cross_attn, cache = next(iter_attn_cache, None), return_intermediates = True)
             elif layer_type == 'f':
```
```diff
@@ -1576,6 +1593,7 @@ class TransformerWrapper(nn.Module):
         return_mems = False,
         return_attn = False,
         mems = None,
+        mem_masks = None,
         pos = None,
         prepend_embeds = None,
         prepend_mask = None,
```
```diff
@@ -1669,7 +1687,7 @@ class TransformerWrapper(nn.Module):
             mems_l, mems_r = mems[:self.shift_mem_down], mems[self.shift_mem_down:]
             mems = [*mems_r, *mems_l]
 
-        x, intermediates = self.attn_layers(x, mask = mask, mems = mems, cache = cache, return_hiddens = True, seq_start_pos = seq_start_pos, **kwargs)
+        x, intermediates = self.attn_layers(x, mask = mask, mems = mems, mem_masks = mem_masks, cache = cache, return_hiddens = True, seq_start_pos = seq_start_pos, **kwargs)
 
         if has_memory_tokens:
             if exists(mem_every):
```
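End to end, the release threads `mem_masks` from `TransformerWrapper.forward` through `AttentionLayers` into each `Attention` block, so callers supplying Transformer-XL style memories can now mark padded memory slots as invalid. A hedged usage sketch against the public `TransformerWrapper` / `Decoder` API (model hyperparameters, shapes, and the padding pattern are illustrative assumptions; only the `mem_masks` keyword itself comes from this diff):

```python
import torch
from x_transformers import TransformerWrapper, Decoder

depth, dim = 6, 512

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(dim = dim, depth = depth, heads = 8, rotary_pos_emb = True)
)

batch, seq_len, mem_len = 2, 128, 64
x = torch.randint(0, 20000, (batch, seq_len))

# one memory tensor per layer, plus a boolean mask marking which memory slots are real
mems      = [torch.randn(batch, mem_len, dim) for _ in range(depth)]
mem_masks = [torch.ones(batch, mem_len, dtype = torch.bool) for _ in range(depth)]
for m in mem_masks:
    m[:, :16] = False  # e.g. the first 16 memory slots are padding

logits = model(x, mems = mems, mem_masks = mem_masks)  # (batch, seq_len, num_tokens)
```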
{x_transformers-1.27.3.dist-info → x_transformers-1.27.4.dist-info}/RECORD
CHANGED
```diff
@@ -3,11 +3,11 @@ x_transformers/attend.py,sha256=MFl_FbgPsm9mziZPTi_s8QbxASETwbGeciMH8sUIwT8,1018
 x_transformers/autoregressive_wrapper.py,sha256=gYKIN5Rm8dMYSTX5yHpg9sPYyZf9rsRTJCNrYRdJ-Ww,9618
 x_transformers/continuous.py,sha256=SAZGR-3BgXU7OEQtjg1_9FnrUBpIyVfXfpMrH-oL5rU,6074
 x_transformers/nonautoregressive_wrapper.py,sha256=AQLE4rA_Kh8VNoe9OzpwyeWson34sRkhks4dn4seNjI,10414
-x_transformers/x_transformers.py,sha256=
+x_transformers/x_transformers.py,sha256=4ggmuqOPhVYE-yXHvfH7ihPJH6kbV-FpqtbvHEhYeKg,63289
 x_transformers/xl_autoregressive_wrapper.py,sha256=DCx4n0_c1tFai4nOqaWVnqx2p9eutsZsDMiMP1ckxNU,4117
 x_transformers/xval.py,sha256=ulEPep6i5Hl7H-H9vGfdsmHdprUmK8ajB306jViyV2c,8147
-x_transformers-1.27.
-x_transformers-1.27.
-x_transformers-1.27.
-x_transformers-1.27.
-x_transformers-1.27.
+x_transformers-1.27.4.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.27.4.dist-info/METADATA,sha256=XMUCtTrKS4EHqYj_B3nW8aQDUIJkI_h8_mJXbErxCjI,661
+x_transformers-1.27.4.dist-info/WHEEL,sha256=oiQVh_5PnQM0E3gPdiz09WCNmwiHDMaGer_elqB3coM,92
+x_transformers-1.27.4.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.27.4.dist-info/RECORD,,
```
{x_transformers-1.27.3.dist-info → x_transformers-1.27.4.dist-info}/LICENSE
File without changes
{x_transformers-1.27.3.dist-info → x_transformers-1.27.4.dist-info}/WHEEL
File without changes
{x_transformers-1.27.3.dist-info → x_transformers-1.27.4.dist-info}/top_level.txt
File without changes