x-transformers 1.27.2__py3-none-any.whl → 1.27.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
x_transformers/continuous.py CHANGED
@@ -48,7 +48,7 @@ class ContinuousTransformerWrapper(nn.Module):
 
         self.max_mem_len = max_mem_len
 
-        if not (use_abs_pos_emb and not attn_layers.has_pos_emb):
+        if not (use_abs_pos_emb and not attn_layers.disable_abs_pos_emb):
             self.pos_emb = always(0)
         elif scaled_sinu_pos_emb:
             self.pos_emb = ScaledSinusoidalEmbedding(dim)
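The guard above now keys off attn_layers.disable_abs_pos_emb: when the attention layers already supply positions (rotary, relative, etc.), the wrapper falls back to a constant-zero "embedding". A minimal sketch of that fallback, using a function stand-in for the x-transformers always helper (the helper shape here is an assumption for illustration):

import torch

def always(val):
    # stand-in for the x-transformers `always` helper: ignores its inputs, returns `val`
    def inner(*args, **kwargs):
        return val
    return inner

pos_emb = always(0)                       # absolute positions disabled
tokens = torch.randn(1, 8, 512)
out = tokens + pos_emb(torch.arange(8))   # the positional term contributes nothing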
x_transformers/x_transformers.py CHANGED
@@ -813,6 +813,7 @@ class Attention(nn.Module):
         rotary_pos_emb = None,
         prev_attn = None,
         mem = None,
+        mem_mask = None,
         return_intermediates = False,
         cache: Optional[Intermediates] = None,
     ):
@@ -879,8 +880,15 @@ class Attention(nn.Module):
         if not exists(input_mask) and not has_context:
             input_mask = mask
 
-        if exists(input_mask) and exists(mem):
-            input_mask = pad_at_dim(input_mask, (mem.shape[-2], 0), dim = -1, value = True)
+        if (exists(input_mask) or exists(mem_mask)) and exists(mem):
+            seq_len, mem_len = n, mem.shape[-2]
+
+            if not exists(mem_mask):
+                input_mask = pad_at_dim(input_mask, (mem_len, 0), dim = -1, value = True)
+            elif not exists(input_mask):
+                input_mask = pad_at_dim(mem_mask, (0, seq_len), dim = -1, value = True)
+            else:
+                input_mask = torch.cat((mem_mask, input_mask), dim = -1)
 
         if self.num_mem_kv > 0:
             mem_k, mem_v = map(lambda t: repeat(t, 'h n d -> b h n d', b = b), (self.mem_k, self.mem_v))
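To make the new branching concrete: the memory mask covers the recurrent memories that sit to the left of the current tokens, and the three branches cover the cases where only one of the two masks is provided versus both. A standalone sketch (shapes are illustrative; pad_at_dim mirrors the x-transformers helper):

import torch
import torch.nn.functional as F

def exists(val):
    return val is not None

def pad_at_dim(t, pad, dim = -1, value = False):
    # pad `t` along `dim` by (left, right), mirroring the x-transformers helper
    dims_from_right = (-dim - 1) if dim < 0 else (t.ndim - dim - 1)
    zeros = (0, 0) * dims_from_right
    return F.pad(t, (*zeros, *pad), value = value)

batch, seq_len, mem_len = 2, 5, 3
input_mask = torch.ones(batch, seq_len, dtype = torch.bool)   # mask over current tokens
mem_mask   = torch.tensor([[True, True, False]] * batch)      # mask over memory tokens

if not exists(mem_mask):
    # no memory mask given: treat all memory positions as valid
    merged = pad_at_dim(input_mask, (mem_len, 0), dim = -1, value = True)
elif not exists(input_mask):
    # no input mask given: treat all current tokens as valid
    merged = pad_at_dim(mem_mask, (0, seq_len), dim = -1, value = True)
else:
    # both given: memories are prepended, so their mask goes on the left
    merged = torch.cat((mem_mask, input_mask), dim = -1)

assert merged.shape == (batch, mem_len + seq_len)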
@@ -1221,6 +1229,7 @@ class AttentionLayers(nn.Module):
         attn_mask = None,
         self_attn_kv_mask = None,
         mems = None,
+        mem_masks = None,
         seq_start_pos: Optional[Tensor] = None,
         cache: Optional[LayerIntermediates] = None,
         cache_age = 1,
@@ -1239,6 +1248,7 @@ class AttentionLayers(nn.Module):
         prev_cross_attn = None
 
         mems = mems.copy() if exists(mems) else [None] * self.num_attn_layers
+        mem_masks = mem_masks.copy() if exists(mem_masks) else [None] * self.num_attn_layers
 
         # handle left padded sequences
 
@@ -1255,7 +1265,12 @@ class AttentionLayers(nn.Module):
 
         if not exists(rotary_pos_emb) and exists(self.rotary_pos_emb):
             max_rotary_emb_length = max(list(map(lambda m: (m.shape[1] if exists(m) else 0) + x.shape[1], mems)))
-            rotary_pos_emb = self.rotary_pos_emb.forward_from_seq_len(max_rotary_emb_length)
+
+            maybe_mem = mems[0] # todo - handle edge case where different layers get different memory lengths. don't think this will ever come up but who knows
+            mem_len = maybe_mem.shape[1] if exists(maybe_mem) else 0
+
+            pos = torch.arange(x.shape[1] + mem_len, device = x.device) - mem_len
+            rotary_pos_emb = self.rotary_pos_emb(pos)
 
         # assume cached key / values
 
@@ -1296,7 +1311,9 @@ class AttentionLayers(nn.Module):
             if layer_type == 'a':
                 if return_hiddens:
                     hiddens.append(x)
+
                 layer_mem = mems.pop(0) if mems else None
+                layer_mem_mask = mem_masks.pop(0) if mem_masks else None
 
             if layer_type == 'c':
                 if self.training and self.cross_attn_tokens_dropout > 0.:
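As the hunk above shows, mems and mem_masks are consumed in lockstep, one entry per attention block, so callers are expected to pass one memory tensor and one matching mask per layer. A toy sketch of that bookkeeping (depth, batch size, and dimensions are made-up values):

import torch

num_attn_layers, batch, mem_len, dim = 6, 2, 4, 512

mems      = [torch.randn(batch, mem_len, dim) for _ in range(num_attn_layers)]
mem_masks = [torch.ones(batch, mem_len, dtype = torch.bool) for _ in range(num_attn_layers)]

for _ in range(num_attn_layers):
    layer_mem      = mems.pop(0) if mems else None
    layer_mem_mask = mem_masks.pop(0) if mem_masks else None
    # each attention block gets its own memory slice and the matching mask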
@@ -1313,7 +1330,7 @@ class AttentionLayers(nn.Module):
             x = pre_norm(x)
 
             if layer_type == 'a':
-                out, inter = block(x, mask = mask, context_mask = self_attn_kv_mask, attn_mask = attn_mask, rel_pos = self.rel_pos, rotary_pos_emb = rotary_pos_emb, prev_attn = prev_attn, cache = next(iter_attn_cache, None), mem = layer_mem, return_intermediates = True)
+                out, inter = block(x, mask = mask, context_mask = self_attn_kv_mask, attn_mask = attn_mask, rel_pos = self.rel_pos, rotary_pos_emb = rotary_pos_emb, prev_attn = prev_attn, cache = next(iter_attn_cache, None), mem = layer_mem, mem_mask = layer_mem_mask, return_intermediates = True)
             elif layer_type == 'c':
                 out, inter = block(x, context = context, mask = mask, context_mask = context_mask, prev_attn = prev_cross_attn, cache = next(iter_attn_cache, None), return_intermediates = True)
             elif layer_type == 'f':
@@ -1576,6 +1593,7 @@ class TransformerWrapper(nn.Module):
         return_mems = False,
         return_attn = False,
         mems = None,
+        mem_masks = None,
         pos = None,
         prepend_embeds = None,
         prepend_mask = None,
@@ -1669,7 +1687,7 @@ class TransformerWrapper(nn.Module):
             mems_l, mems_r = mems[:self.shift_mem_down], mems[self.shift_mem_down:]
             mems = [*mems_r, *mems_l]
 
-        x, intermediates = self.attn_layers(x, mask = mask, mems = mems, cache = cache, return_hiddens = True, seq_start_pos = seq_start_pos, **kwargs)
+        x, intermediates = self.attn_layers(x, mask = mask, mems = mems, mem_masks = mem_masks, cache = cache, return_hiddens = True, seq_start_pos = seq_start_pos, **kwargs)
 
         if has_memory_tokens:
             if exists(mem_every):
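Putting the pieces together, the new mem_masks argument threads from TransformerWrapper.forward down through AttentionLayers to each Attention block. A hedged usage sketch of XL-style recurrence with the new argument (hyperparameters and shapes are illustrative, not prescriptive):

import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    max_mem_len = 256,
    attn_layers = Decoder(dim = 512, depth = 6, heads = 8, rotary_pos_emb = True)
)

seg1 = torch.randint(0, 20000, (1, 1024))
logits1, mems = model(seg1, return_mems = True)   # keep per-layer memories

# build one boolean mask per layer's memory; mask out whatever positions should be ignored
mem_masks = [torch.ones(m.shape[:2], dtype = torch.bool) for m in mems]

seg2 = torch.randint(0, 20000, (1, 1024))
logits2 = model(seg2, mems = mems, mem_masks = mem_masks)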
x_transformers/xval.py CHANGED
@@ -74,7 +74,7 @@ class XValTransformerWrapper(nn.Module):
 
         self.max_mem_len = max_mem_len
 
-        if not (use_abs_pos_emb and not attn_layers.has_pos_emb):
+        if not (use_abs_pos_emb and not attn_layers.disable_abs_pos_emb):
             self.pos_emb = always(0)
         elif scaled_sinu_pos_emb:
             self.pos_emb = ScaledSinusoidalEmbedding(dim)
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: x-transformers
-Version: 1.27.2
+Version: 1.27.4
 Summary: X-Transformers - Pytorch
 Home-page: https://github.com/lucidrains/x-transformers
 Author: Phil Wang
@@ -0,0 +1,13 @@
+x_transformers/__init__.py,sha256=pXc_U4M3ONUQcpNgZySDIlCF1rp7u4FFmcOYjc4WuXw,629
+x_transformers/attend.py,sha256=MFl_FbgPsm9mziZPTi_s8QbxASETwbGeciMH8sUIwT8,10188
+x_transformers/autoregressive_wrapper.py,sha256=gYKIN5Rm8dMYSTX5yHpg9sPYyZf9rsRTJCNrYRdJ-Ww,9618
+x_transformers/continuous.py,sha256=SAZGR-3BgXU7OEQtjg1_9FnrUBpIyVfXfpMrH-oL5rU,6074
+x_transformers/nonautoregressive_wrapper.py,sha256=AQLE4rA_Kh8VNoe9OzpwyeWson34sRkhks4dn4seNjI,10414
+x_transformers/x_transformers.py,sha256=4ggmuqOPhVYE-yXHvfH7ihPJH6kbV-FpqtbvHEhYeKg,63289
+x_transformers/xl_autoregressive_wrapper.py,sha256=DCx4n0_c1tFai4nOqaWVnqx2p9eutsZsDMiMP1ckxNU,4117
+x_transformers/xval.py,sha256=ulEPep6i5Hl7H-H9vGfdsmHdprUmK8ajB306jViyV2c,8147
+x_transformers-1.27.4.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.27.4.dist-info/METADATA,sha256=XMUCtTrKS4EHqYj_B3nW8aQDUIJkI_h8_mJXbErxCjI,661
+x_transformers-1.27.4.dist-info/WHEEL,sha256=oiQVh_5PnQM0E3gPdiz09WCNmwiHDMaGer_elqB3coM,92
+x_transformers-1.27.4.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.27.4.dist-info/RECORD,,
@@ -1,13 +0,0 @@
-x_transformers/__init__.py,sha256=pXc_U4M3ONUQcpNgZySDIlCF1rp7u4FFmcOYjc4WuXw,629
-x_transformers/attend.py,sha256=MFl_FbgPsm9mziZPTi_s8QbxASETwbGeciMH8sUIwT8,10188
-x_transformers/autoregressive_wrapper.py,sha256=gYKIN5Rm8dMYSTX5yHpg9sPYyZf9rsRTJCNrYRdJ-Ww,9618
-x_transformers/continuous.py,sha256=Ra5IClCl9G7SAiM6L9w6iY4cCznH0dSGljC9AC_bNyw,6066
-x_transformers/nonautoregressive_wrapper.py,sha256=AQLE4rA_Kh8VNoe9OzpwyeWson34sRkhks4dn4seNjI,10414
-x_transformers/x_transformers.py,sha256=dgR0FlSnpkcC52rJ4BNcWP0q5Q00nPGY2UAhIk9VvSA,62371
-x_transformers/xl_autoregressive_wrapper.py,sha256=DCx4n0_c1tFai4nOqaWVnqx2p9eutsZsDMiMP1ckxNU,4117
-x_transformers/xval.py,sha256=lS9W_E_RskPQAqVZkPiUzbByoW1Ajsw_phsikA3JDAg,8139
-x_transformers-1.27.2.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-1.27.2.dist-info/METADATA,sha256=UEysy_lZFZfTM6RuFQ2O4g8G_NAzoN_lihQzCN0DXkg,661
-x_transformers-1.27.2.dist-info/WHEEL,sha256=oiQVh_5PnQM0E3gPdiz09WCNmwiHDMaGer_elqB3coM,92
-x_transformers-1.27.2.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
-x_transformers-1.27.2.dist-info/RECORD,,