x-transformers 1.42.24__py3-none-any.whl → 1.42.26__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- x_transformers/x_transformers.py +33 -11
- {x_transformers-1.42.24.dist-info → x_transformers-1.42.26.dist-info}/METADATA +1 -1
- {x_transformers-1.42.24.dist-info → x_transformers-1.42.26.dist-info}/RECORD +6 -6
- {x_transformers-1.42.24.dist-info → x_transformers-1.42.26.dist-info}/LICENSE +0 -0
- {x_transformers-1.42.24.dist-info → x_transformers-1.42.26.dist-info}/WHEEL +0 -0
- {x_transformers-1.42.24.dist-info → x_transformers-1.42.26.dist-info}/top_level.txt +0 -0
 
    
x_transformers/x_transformers.py CHANGED

@@ -51,8 +51,8 @@ def default(val, d):
         return val
     return d() if callable(d) else d
 
-def first(it):
-    return it[0]
+def first(it, default = None):
+    return it[0] if len(it) > 0 else default
 
 def is_empty(x):
     return len(x) == 0

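The first() helper previously assumed a non-empty sequence and simply returned it[0]; it now accepts a fallback value, which is what allows the first(mems, None) call in the AttentionLayers.forward hunk further down. A minimal sketch of the new behavior (the values are illustrative; first is an internal helper of x_transformers.x_transformers):

    from x_transformers.x_transformers import first

    first([3, 4, 5])          # -> 3
    first([], default = 0)    # -> 0 (previously first([]) raised an IndexError)
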
@@ -1284,6 +1284,7 @@ class Attention(Module):
         rel_pos = None,
         attn_bias = None,
         rotary_pos_emb = None,
+        context_rotary_pos_emb = None,
         pos = None, # for custom alibi positions
         prev_attn = None,
         mem = None,

@@ -1355,11 +1356,18 @@ class Attention(Module):
             q = q * self.qk_norm_q_scale
             k = k * self.qk_norm_k_scale
 
-        if exists(rotary_pos_emb) and not has_context:
+        if exists(rotary_pos_emb):
             freqs, xpos_scale = rotary_pos_emb
             q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale ** -1.) if exists(xpos_scale) else (1., 1.)
 
             q = apply_rotary_pos_emb(q, freqs, q_xpos_scale)
+
+            if has_context:
+                # override with `context_rotary_pos_emb` if provided
+
+                freqs, xpos_scale = context_rotary_pos_emb
+                _, k_xpos_scale = (xpos_scale, xpos_scale ** -1.) if exists(xpos_scale) else (1., 1.)
+
             k = apply_rotary_pos_emb(k, freqs, k_xpos_scale)
 
             if self.rotary_embed_values:

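Together with the new context_rotary_pos_emb argument above, this hunk makes cross attention keep the query rotations of the sequence being attended from while re-rotating the keys with frequencies derived from the context positions. A toy restatement of the key-side selection, using a hypothetical standalone helper that is not part of the library:

    def rotary_freqs_for_keys(rotary_pos_emb, context_rotary_pos_emb, has_context):
        # queries always use rotary_pos_emb
        freqs, xpos_scale = rotary_pos_emb

        if has_context:
            # keys come from the context, so switch to the context frequencies and scale
            freqs, xpos_scale = context_rotary_pos_emb

        k_xpos_scale = xpos_scale ** -1. if xpos_scale is not None else 1.
        return freqs, k_xpos_scale

As in the diff, the override is applied whenever a context is present; the AttentionLayers changes below only pass rotary embeddings to a cross-attention layer when context positions were supplied, so context_rotary_pos_emb is never None at that point.
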
@@ -1848,7 +1856,6 @@ class AttentionLayers(Module):
                 layer = Attention(dim, heads = heads, causal = causal, learned_value_residual_mix = self_attn_learned_value_residual, **attn_kwargs)
                 is_first_self_attn = False
             elif layer_type == 'c':
-                cross_attn_learned_value_residual = learned_value_residual_mix and not is_first_cross_attn
                 layer = Attention(dim, heads = heads, **{**attn_kwargs, **cross_attn_kwargs})
                 is_first_cross_attn = False
             elif layer_type == 'f':

@@ -1917,6 +1924,7 @@ class AttentionLayers(Module):
         return_hiddens = False,
         rotary_pos_emb = None,
         pos = None,
+        context_pos = None,
         attn_bias = None,
         condition = None,
         in_attn_cond = None, # https://arxiv.org/abs/2105.04090

@@ -1976,14 +1984,28 @@ class AttentionLayers(Module):
 
         # rotary positions
 
-        if not exists(rotary_pos_emb) and exists(self.rotary_pos_emb):
-            maybe_mem = mems[0] if exists(mems) else None # todo - handle edge case where different layers get different memory lengths. don't think this will ever come up but who knows
-            mem_len = maybe_mem.shape[1] if exists(maybe_mem) else 0
+        cross_attn_rotary_pos_emb = dict()
+
+        if exists(self.rotary_pos_emb):
+            if not exists(rotary_pos_emb):
+                maybe_mem = first(mems, None) # todo - handle edge case where different layers get different memory lengths. don't think this will ever come up but who knows
+                mem_len = maybe_mem.shape[1] if exists(maybe_mem) else 0
+
+                if not exists(pos):
+                    pos = torch.arange(x.shape[1] + mem_len, device = x.device) - mem_len
+
+                rotary_pos_emb = self.rotary_pos_emb(pos)
+
+            # allow for rotary positions for context if provided
 
-            if not exists(pos):
-                pos = torch.arange(x.shape[1] + mem_len, device = x.device) - mem_len
+            if exists(context_pos):
+                assert self.cross_attend
+                context_rotary_pos_emb = self.rotary_pos_emb(context_pos)
 
-            rotary_pos_emb = self.rotary_pos_emb(pos)
+                cross_attn_rotary_pos_emb.update(
+                    rotary_pos_emb = rotary_pos_emb,
+                    context_rotary_pos_emb = context_rotary_pos_emb
+                )
 
         # assume cached key / values
 
@@ -2108,7 +2130,7 @@ class AttentionLayers(Module):
             if layer_type == 'a':
                 out, inter = block(x, mask = mask, context_mask = self_attn_kv_mask, attn_mask = attn_mask, rel_pos = self.rel_pos, pos = pos, rotary_pos_emb = rotary_pos_emb, prev_attn = prev_attn, cache = next(iter_attn_cache, None), mem = layer_mem, mem_mask = layer_mem_mask, attn_bias = attn_bias, value_residual = maybe_self_attn_value_residual, return_intermediates = True)
             elif layer_type == 'c':
-                out, inter = block(x, context = context, mask = mask, context_mask = context_mask, prev_attn = prev_cross_attn, cache = next(iter_attn_cache, None), value_residual = maybe_cross_attn_value_residual, return_intermediates = True)
+                out, inter = block(x, context = context, mask = mask, context_mask = context_mask, prev_attn = prev_cross_attn, cache = next(iter_attn_cache, None), value_residual = maybe_cross_attn_value_residual, **cross_attn_rotary_pos_emb, return_intermediates = True)
             elif layer_type == 'f':
                 out = block(x)
 
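Net effect of the x_transformers.py changes: a cross-attending attention stack with rotary embeddings enabled can now be handed explicit rotary positions for its context sequence via context_pos, and those positions are routed to the keys of every cross-attention layer. A minimal usage sketch, assuming 1.42.26; the choice of Decoder and all shapes and hyperparameters below are illustrative, not taken from the diff:

    import torch
    from x_transformers import Decoder

    attn_layers = Decoder(
        dim = 512,
        depth = 2,
        heads = 8,
        rotary_pos_emb = True,   # rotary embeddings must be enabled for context_pos to have an effect
        cross_attend = True      # context_pos asserts self.cross_attend
    )

    x = torch.randn(1, 10, 512)         # embeddings of the sequence being decoded
    context = torch.randn(1, 77, 512)   # context embeddings to cross-attend to
    context_pos = torch.arange(77)      # rotary positions for the context

    out = attn_layers(x, context = context, context_pos = context_pos)   # (1, 10, 512)

If context_pos is omitted, the cross-attention layers receive no rotary embeddings at all, which matches the behavior of 1.42.24.
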
{x_transformers-1.42.24.dist-info → x_transformers-1.42.26.dist-info}/RECORD CHANGED

@@ -6,11 +6,11 @@ x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/neo_mlp.py,sha256=XCNnnop9WLarcxap1kGuYc1x8GHvwkZiDRnXOxSl3Po,3452
 x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
-x_transformers/x_transformers.py,sha256=
+x_transformers/x_transformers.py,sha256=iocWSfj6h0GvmvXBIRBqzZm8l4IZDf4so0pOcIy3jFg,96696
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=7S00kCuab4tWQa-vf-z-XfzADjVj48MoFIr7VSIvttg,8575
-x_transformers-1.42.
-x_transformers-1.42.
-x_transformers-1.42.
-x_transformers-1.42.
-x_transformers-1.42.
+x_transformers-1.42.26.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.42.26.dist-info/METADATA,sha256=cf7t334edZKkTEpnwlhTQNYRgLDU6SSBPxtsA58ubxY,739
+x_transformers-1.42.26.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
+x_transformers-1.42.26.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.42.26.dist-info/RECORD,,
         
{x_transformers-1.42.24.dist-info → x_transformers-1.42.26.dist-info}/LICENSE: File without changes
{x_transformers-1.42.24.dist-info → x_transformers-1.42.26.dist-info}/WHEEL: File without changes
{x_transformers-1.42.24.dist-info → x_transformers-1.42.26.dist-info}/top_level.txt: File without changes