x-transformers 1.42.24__py3-none-any.whl → 1.42.26__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
@@ -51,8 +51,8 @@ def default(val, d):
         return val
     return d() if callable(d) else d

-def first(it):
-    return it[0]
+def first(it, default = None):
+    return it[0] if len(it) > 0 else default

 def is_empty(x):
     return len(x) == 0
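
The `first` helper previously assumed a non-empty sequence and raised an IndexError on an empty one; it now accepts an optional default. A minimal standalone sketch of the new behavior (reimplemented here for illustration, not imported from the package):

    def first(it, default = None):
        # return the first element, or `default` when the sequence is empty
        return it[0] if len(it) > 0 else default

    assert first([3, 2, 1]) == 3
    assert first([]) is None
    assert first([], default = 0) == 0

This is what lets the layer code further down replace `mems[0]` with `first(mems, None)` without special-casing an empty memory list.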
@@ -1284,6 +1284,7 @@ class Attention(Module):
         rel_pos = None,
         attn_bias = None,
         rotary_pos_emb = None,
+        context_rotary_pos_emb = None,
         pos = None, # for custom alibi positions
         prev_attn = None,
         mem = None,
@@ -1355,11 +1356,18 @@ class Attention(Module):
1355
1356
  q = q * self.qk_norm_q_scale
1356
1357
  k = k * self.qk_norm_k_scale
1357
1358
 
1358
- if exists(rotary_pos_emb) and not has_context:
1359
+ if exists(rotary_pos_emb):
1359
1360
  freqs, xpos_scale = rotary_pos_emb
1360
1361
  q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale ** -1.) if exists(xpos_scale) else (1., 1.)
1361
1362
 
1362
1363
  q = apply_rotary_pos_emb(q, freqs, q_xpos_scale)
1364
+
1365
+ if has_context:
1366
+ # override with `context_rotary_pos_emb` if provided
1367
+
1368
+ freqs, xpos_scale = context_rotary_pos_emb
1369
+ _, k_xpos_scale = (xpos_scale, xpos_scale ** -1.) if exists(xpos_scale) else (1., 1.)
1370
+
1363
1371
  k = apply_rotary_pos_emb(k, freqs, k_xpos_scale)
1364
1372
 
1365
1373
  if self.rotary_embed_values:
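
Previously the `not has_context` guard skipped rotary embeddings entirely for cross attention. Now the queries are always rotated with their own `rotary_pos_emb`, and when the layer attends to a context, the key rotation is overridden by `context_rotary_pos_emb` (and the inverse xpos scale derived from it). A rough standalone sketch of the idea, rotating queries and keys with different position tables; the helpers below are generic stand-ins, not the package's `apply_rotary_pos_emb`:

    import torch

    def rotate_half(x):
        # split the feature dim into halves and rotate: (a, b) -> (-b, a)
        x1, x2 = x.chunk(2, dim = -1)
        return torch.cat((-x2, x1), dim = -1)

    def rotary_freqs(pos, dim):
        # per-position angles, duplicated across both halves of the feature dim
        inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        angles = torch.einsum('i,j->ij', pos.float(), inv_freq)
        return torch.cat((angles, angles), dim = -1)

    def apply_rotary(t, freqs):
        return t * freqs.cos() + rotate_half(t) * freqs.sin()

    dim_head = 16
    q = torch.randn(1, 4, 8, dim_head)    # (batch, heads, query length, dim_head)
    k = torch.randn(1, 4, 12, dim_head)   # keys come from a longer context sequence

    q = apply_rotary(q, rotary_freqs(torch.arange(8), dim_head))    # query positions
    k = apply_rotary(k, rotary_freqs(torch.arange(12), dim_head))   # context positions

Note that when `has_context` is true, the override assumes `context_rotary_pos_emb` was actually passed in; the caller in `AttentionLayers` only forwards rotary arguments to cross attention when context positions are provided.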
@@ -1848,7 +1856,6 @@ class AttentionLayers(Module):
                 layer = Attention(dim, heads = heads, causal = causal, learned_value_residual_mix = self_attn_learned_value_residual, **attn_kwargs)
                 is_first_self_attn = False
             elif layer_type == 'c':
-                cross_attn_learned_value_residual = learned_value_residual_mix and not is_first_cross_attn
                 layer = Attention(dim, heads = heads, **{**attn_kwargs, **cross_attn_kwargs})
                 is_first_cross_attn = False
             elif layer_type == 'f':
@@ -1917,6 +1924,7 @@ class AttentionLayers(Module):
         return_hiddens = False,
         rotary_pos_emb = None,
         pos = None,
+        context_pos = None,
         attn_bias = None,
         condition = None,
         in_attn_cond = None, # https://arxiv.org/abs/2105.04090
@@ -1976,14 +1984,28 @@ class AttentionLayers(Module):

         # rotary positions

-        if not exists(rotary_pos_emb) and exists(self.rotary_pos_emb):
-            maybe_mem = mems[0] # todo - handle edge case where different layers get different memory lengths. don't think this will ever come up but who knows
-            mem_len = maybe_mem.shape[1] if exists(maybe_mem) else 0
+        cross_attn_rotary_pos_emb = dict()
+
+        if exists(self.rotary_pos_emb):
+            if not exists(rotary_pos_emb):
+                maybe_mem = first(mems, None) # todo - handle edge case where different layers get different memory lengths. don't think this will ever come up but who knows
+                mem_len = maybe_mem.shape[1] if exists(maybe_mem) else 0
+
+                if not exists(pos):
+                    pos = torch.arange(x.shape[1] + mem_len, device = x.device) - mem_len
+
+                rotary_pos_emb = self.rotary_pos_emb(pos)
+
+            # allow for rotary positions for context if provided

-            if not exists(pos):
-                pos = torch.arange(x.shape[1] + mem_len, device = x.device) - mem_len
+            if exists(context_pos):
+                assert self.cross_attend
+                context_rotary_pos_emb = self.rotary_pos_emb(context_pos)

-            rotary_pos_emb = self.rotary_pos_emb(pos)
+                cross_attn_rotary_pos_emb.update(
+                    rotary_pos_emb = rotary_pos_emb,
+                    context_rotary_pos_emb = context_rotary_pos_emb
+                )

         # assume cached key / values

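This hunk restructures the rotary-position setup in `AttentionLayers.forward`: `mems[0]` becomes `first(mems, None)` (safe when no memories are passed), and when `context_pos` is given, the self- and context-rotary embeddings are packed into a `cross_attn_rotary_pos_emb` dict that is later splatted into the cross-attention call. A small sketch of that optional-kwargs pattern with hypothetical names, assuming keyword defaults on the callee:

    def cross_attn_block(x, context, rotary_pos_emb = None, context_rotary_pos_emb = None):
        # stand-in for the cross-attention layer; rotary arguments are optional
        rotated = rotary_pos_emb is not None and context_rotary_pos_emb is not None
        return x, rotated

    cross_attn_rotary_pos_emb = dict()
    context_pos = [0, 1, 2]   # pretend positions; None would leave the dict empty

    if context_pos is not None:
        cross_attn_rotary_pos_emb.update(
            rotary_pos_emb = 'freqs for x',
            context_rotary_pos_emb = 'freqs for context',
        )

    # with an empty dict the call is exactly what it was before this change
    out, rotated = cross_attn_block('x', 'context', **cross_attn_rotary_pos_emb)

One consequence worth noting: the `assert self.cross_attend` means passing `context_pos` to a stack built without cross attention fails loudly instead of being silently ignored.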
@@ -2108,7 +2130,7 @@ class AttentionLayers(Module):
             if layer_type == 'a':
                 out, inter = block(x, mask = mask, context_mask = self_attn_kv_mask, attn_mask = attn_mask, rel_pos = self.rel_pos, pos = pos, rotary_pos_emb = rotary_pos_emb, prev_attn = prev_attn, cache = next(iter_attn_cache, None), mem = layer_mem, mem_mask = layer_mem_mask, attn_bias = attn_bias, value_residual = maybe_self_attn_value_residual, return_intermediates = True)
             elif layer_type == 'c':
-                out, inter = block(x, context = context, mask = mask, context_mask = context_mask, prev_attn = prev_cross_attn, cache = next(iter_attn_cache, None), value_residual = maybe_cross_attn_value_residual, return_intermediates = True)
+                out, inter = block(x, context = context, mask = mask, context_mask = context_mask, prev_attn = prev_cross_attn, cache = next(iter_attn_cache, None), value_residual = maybe_cross_attn_value_residual, **cross_attn_rotary_pos_emb, return_intermediates = True)
             elif layer_type == 'f':
                 out = block(x)

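Taken together, `AttentionLayers.forward` now accepts `context_pos`, and the rotary frequencies computed from it are forwarded to every cross-attention layer. A hedged usage sketch, assuming a `Decoder` built with `cross_attend = True` and `rotary_pos_emb = True` and called directly on embeddings (shapes and positions are made up for illustration):

    import torch
    from x_transformers import Decoder

    decoder = Decoder(
        dim = 512,
        depth = 2,
        heads = 8,
        cross_attend = True,     # interleave cross-attention layers
        rotary_pos_emb = True    # enable the shared rotary embedding module
    )

    x = torch.randn(1, 10, 512)          # decoder-side embeddings
    context = torch.randn(1, 20, 512)    # encoder / context embeddings

    # positions for the context tokens; omit this and cross attention behaves as before
    context_pos = torch.arange(20)

    out = decoder(x, context = context, context_pos = context_pos)

Without `context_pos` the `cross_attn_rotary_pos_emb` dict stays empty, so existing callers are unaffected.
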
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: x-transformers
-Version: 1.42.24
+Version: 1.42.26
 Summary: X-Transformers - Pytorch
 Home-page: https://github.com/lucidrains/x-transformers
 Author: Phil Wang
@@ -6,11 +6,11 @@ x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/neo_mlp.py,sha256=XCNnnop9WLarcxap1kGuYc1x8GHvwkZiDRnXOxSl3Po,3452
 x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
-x_transformers/x_transformers.py,sha256=yaC5Jh2sXDRADTjUZHkrJmcJmb4s-aWjrbamVQLAv0s,95928
+x_transformers/x_transformers.py,sha256=iocWSfj6h0GvmvXBIRBqzZm8l4IZDf4so0pOcIy3jFg,96696
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=7S00kCuab4tWQa-vf-z-XfzADjVj48MoFIr7VSIvttg,8575
-x_transformers-1.42.24.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-1.42.24.dist-info/METADATA,sha256=6gq8sWjWzyazL_0CCyfN05PMNxApuNNLu2AeN3sGYkA,739
-x_transformers-1.42.24.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
-x_transformers-1.42.24.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
-x_transformers-1.42.24.dist-info/RECORD,,
+x_transformers-1.42.26.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.42.26.dist-info/METADATA,sha256=cf7t334edZKkTEpnwlhTQNYRgLDU6SSBPxtsA58ubxY,739
+x_transformers-1.42.26.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
+x_transformers-1.42.26.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.42.26.dist-info/RECORD,,