x-transformers 1.42.24__py3-none-any.whl → 1.42.25__py3-none-any.whl

x_transformers/x_transformers.py

@@ -1284,6 +1284,7 @@ class Attention(Module):
         rel_pos = None,
         attn_bias = None,
         rotary_pos_emb = None,
+        context_rotary_pos_emb = None,
         pos = None, # for custom alibi positions
         prev_attn = None,
         mem = None,
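
The new `context_rotary_pos_emb` argument lets cross-attention keys receive rotary frequencies computed from the context sequence's own positions, instead of skipping rotary embedding whenever a context is present (the old `not has_context` guard, removed below). A minimal sketch of the idea follows; `rotate_half`, `rotary_freqs` and `apply_rotary` are simplified stand-ins written for this note, not the library's `RotaryEmbedding` / `apply_rotary_pos_emb`, and all shapes are illustrative.

import torch

def rotate_half(x):
    x1, x2 = x.chunk(2, dim = -1)
    return torch.cat((-x2, x1), dim = -1)

def rotary_freqs(pos, dim, theta = 10000):
    # one frequency per pair of feature dimensions, multiplied by the position
    inv_freq = 1. / (theta ** (torch.arange(0, dim, 2).float() / dim))
    freqs = pos.float()[:, None] * inv_freq[None, :]
    return torch.cat((freqs, freqs), dim = -1)

def apply_rotary(t, freqs):
    return t * freqs.cos() + rotate_half(t) * freqs.sin()

q = torch.randn(2, 8, 10, 64)   # (batch, heads, decoder length, dim head)
k = torch.randn(2, 8, 77, 64)   # (batch, heads, context length, dim head)

# in cross attention the two sequences have independent positions,
# hence the need for a separate set of frequencies for the keys
q = apply_rotary(q, rotary_freqs(torch.arange(10), 64))
k = apply_rotary(k, rotary_freqs(torch.arange(77), 64))   # the role played by `context_rotary_pos_emb`
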
@@ -1355,11 +1356,19 @@ class Attention(Module):
             q = q * self.qk_norm_q_scale
             k = k * self.qk_norm_k_scale

-        if exists(rotary_pos_emb) and not has_context:
+        if exists(rotary_pos_emb):
+
             freqs, xpos_scale = rotary_pos_emb
             q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale ** -1.) if exists(xpos_scale) else (1., 1.)

             q = apply_rotary_pos_emb(q, freqs, q_xpos_scale)
+
+            if has_context:
+                # override with `context_rotary_pos_emb` if provided
+
+                freqs, xpos_scale = context_rotary_pos_emb
+                _, k_xpos_scale = (xpos_scale, xpos_scale ** -1.) if exists(xpos_scale) else (1., 1.)
+
             k = apply_rotary_pos_emb(k, freqs, k_xpos_scale)

         if self.rotary_embed_values:
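
The recomputed `k_xpos_scale = xpos_scale ** -1.` follows the xpos convention already visible in the context lines: queries are scaled by the xpos scale and keys by its inverse, so when query and key share a position the scales cancel in the dot product, and across positions their ratio supplies the length-extrapolating decay. A toy check of the cancellation at a single shared position, with made-up numbers rather than the library's code:

import torch

d = 8
q, k = torch.randn(d), torch.randn(d)
scale = torch.full((d,), 1.2)     # stand-in for an xpos scale at one position

q_scaled = q * scale              # query scaled by s
k_scaled = k * scale ** -1.       # key scaled by s^-1

# the dot product is unchanged when query and key carry the same scale;
# across different positions the ratio of scales is what remains
assert torch.allclose(q @ k, q_scaled @ k_scaled, atol = 1e-5)
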
@@ -1848,7 +1857,6 @@ class AttentionLayers(Module):
                 layer = Attention(dim, heads = heads, causal = causal, learned_value_residual_mix = self_attn_learned_value_residual, **attn_kwargs)
                 is_first_self_attn = False
             elif layer_type == 'c':
-                cross_attn_learned_value_residual = learned_value_residual_mix and not is_first_cross_attn
                 layer = Attention(dim, heads = heads, **{**attn_kwargs, **cross_attn_kwargs})
                 is_first_cross_attn = False
             elif layer_type == 'f':
@@ -1917,6 +1925,7 @@ class AttentionLayers(Module):
         return_hiddens = False,
         rotary_pos_emb = None,
         pos = None,
+        context_pos = None,
         attn_bias = None,
         condition = None,
         in_attn_cond = None, # https://arxiv.org/abs/2105.04090
@@ -1976,14 +1985,28 @@ class AttentionLayers(Module):

         # rotary positions

-        if not exists(rotary_pos_emb) and exists(self.rotary_pos_emb):
-            maybe_mem = mems[0] # todo - handle edge case where different layers get different memory lengths. don't think this will ever come up but who knows
-            mem_len = maybe_mem.shape[1] if exists(maybe_mem) else 0
+        cross_attn_rotary_pos_emb = dict()
+
+        if exists(self.rotary_pos_emb):
+            if not exists(rotary_pos_emb):
+                maybe_mem = mems[0] # todo - handle edge case where different layers get different memory lengths. don't think this will ever come up but who knows
+                mem_len = maybe_mem.shape[1] if exists(maybe_mem) else 0
+
+                if not exists(pos):
+                    pos = torch.arange(x.shape[1] + mem_len, device = x.device) - mem_len
+
+                rotary_pos_emb = self.rotary_pos_emb(pos)
+
+            # allow for rotary positions for context if provided

-            if not exists(pos):
-                pos = torch.arange(x.shape[1] + mem_len, device = x.device) - mem_len
+            if exists(context_pos):
+                assert self.cross_attend
+                context_rotary_pos_emb = self.rotary_pos_emb(context_pos)

-            rotary_pos_emb = self.rotary_pos_emb(pos)
+                cross_attn_rotary_pos_emb.update(
+                    rotary_pos_emb = rotary_pos_emb,
+                    context_rotary_pos_emb = context_rotary_pos_emb
+                )

         # assume cached key / values

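
The default position computation that moved inside the new `if exists(self.rotary_pos_emb)` block is unchanged in substance: positions begin at `-mem_len`, so prepended memories occupy negative positions and the first token of the current sequence sits at 0. A small worked example with made-up lengths:

import torch

mem_len, seq_len = 3, 5
pos = torch.arange(seq_len + mem_len) - mem_len
print(pos)   # tensor([-3, -2, -1,  0,  1,  2,  3,  4])
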
@@ -2108,7 +2131,7 @@ class AttentionLayers(Module):
             if layer_type == 'a':
                 out, inter = block(x, mask = mask, context_mask = self_attn_kv_mask, attn_mask = attn_mask, rel_pos = self.rel_pos, pos = pos, rotary_pos_emb = rotary_pos_emb, prev_attn = prev_attn, cache = next(iter_attn_cache, None), mem = layer_mem, mem_mask = layer_mem_mask, attn_bias = attn_bias, value_residual = maybe_self_attn_value_residual, return_intermediates = True)
             elif layer_type == 'c':
-                out, inter = block(x, context = context, mask = mask, context_mask = context_mask, prev_attn = prev_cross_attn, cache = next(iter_attn_cache, None), value_residual = maybe_cross_attn_value_residual, return_intermediates = True)
+                out, inter = block(x, context = context, mask = mask, context_mask = context_mask, prev_attn = prev_cross_attn, cache = next(iter_attn_cache, None), value_residual = maybe_cross_attn_value_residual, **cross_attn_rotary_pos_emb, return_intermediates = True)
             elif layer_type == 'f':
                 out = block(x)

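
Taken together, the release threads context positions from `AttentionLayers.forward` down to the cross-attention blocks. A hedged usage sketch, assuming a cross-attending `Decoder` with rotary embeddings enabled; dimensions, depth and sequence lengths are illustrative only:

import torch
from x_transformers import Decoder

decoder = Decoder(
    dim = 512,
    depth = 2,
    heads = 8,
    rotary_pos_emb = True,
    cross_attend = True
)

x = torch.randn(1, 10, 512)          # decoder token embeddings
context = torch.randn(1, 77, 512)    # context / encoder embeddings

# positions for the context sequence; per the diff these become
# `context_rotary_pos_emb` and rotate the cross-attention keys
context_pos = torch.arange(77)

out = decoder(x, context = context, context_pos = context_pos)   # (1, 10, 512)

If `context_pos` is omitted, the new kwargs dict stays empty and cross attention runs as it did in 1.42.24, with no rotary embedding applied to the context keys.
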
x_transformers-1.42.24.dist-info/METADATA → x_transformers-1.42.25.dist-info/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: x-transformers
-Version: 1.42.24
+Version: 1.42.25
 Summary: X-Transformers - Pytorch
 Home-page: https://github.com/lucidrains/x-transformers
 Author: Phil Wang
x_transformers-1.42.24.dist-info/RECORD → x_transformers-1.42.25.dist-info/RECORD

@@ -6,11 +6,11 @@ x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/neo_mlp.py,sha256=XCNnnop9WLarcxap1kGuYc1x8GHvwkZiDRnXOxSl3Po,3452
 x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
-x_transformers/x_transformers.py,sha256=yaC5Jh2sXDRADTjUZHkrJmcJmb4s-aWjrbamVQLAv0s,95928
+x_transformers/x_transformers.py,sha256=tj4s_p46Up89RcIFJF4aZ4iWtt4fpDVHKHqXv23Oekk,96643
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=7S00kCuab4tWQa-vf-z-XfzADjVj48MoFIr7VSIvttg,8575
-x_transformers-1.42.24.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-1.42.24.dist-info/METADATA,sha256=6gq8sWjWzyazL_0CCyfN05PMNxApuNNLu2AeN3sGYkA,739
-x_transformers-1.42.24.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
-x_transformers-1.42.24.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
-x_transformers-1.42.24.dist-info/RECORD,,
+x_transformers-1.42.25.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.42.25.dist-info/METADATA,sha256=I2JJliI_WRW_0_tQoigduXIaYgDcU4YGdxJJKJ62BHE,739
+x_transformers-1.42.25.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
+x_transformers-1.42.25.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.42.25.dist-info/RECORD,,