x-transformers 2.11.24__py3-none-any.whl → 2.12.1__py3-none-any.whl
This diff shows the content of publicly available package versions as released to a supported public registry. It is provided for informational purposes only and reflects the changes between the two package versions as they appear in that registry.
Potentially problematic release: this version of x-transformers might be problematic.
- x_transformers/x_transformers.py +68 -3
- {x_transformers-2.11.24.dist-info → x_transformers-2.12.1.dist-info}/METADATA +13 -1
- {x_transformers-2.11.24.dist-info → x_transformers-2.12.1.dist-info}/RECORD +5 -5
- {x_transformers-2.11.24.dist-info → x_transformers-2.12.1.dist-info}/WHEEL +0 -0
- {x_transformers-2.11.24.dist-info → x_transformers-2.12.1.dist-info}/licenses/LICENSE +0 -0
x_transformers/x_transformers.py CHANGED

@@ -779,6 +779,49 @@ def apply_rotary_pos_emb(t, freqs, scale = 1):
 
     return out.type(orig_dtype)
 
+class PolarEmbedding(Module):
+    """ https://arxiv.org/abs/2509.10534 """
+
+    def __init__(
+        self,
+        dim,
+        bias_uniform_init = False,
+        base = 10000,
+    ):
+        super().__init__()
+        inv_freq = 1. / (base ** (arange(0, dim).float() / dim))
+        self.register_buffer('inv_freq', inv_freq)
+
+        self.learned_bias = nn.Parameter(torch.zeros(dim))
+
+        if bias_uniform_init:
+            self.learned_bias.uniform_(-2. * math.pi, 0.)
+
+    @autocast('cuda', enabled = False)
+    def forward(self, t, offset = 0):
+        max_pos = t.max() + 1
+
+        if t.ndim == 1:
+            t = rearrange(t, 'n -> 1 n')
+
+        freqs = torch.einsum('b i , j -> b i j', t.type_as(self.inv_freq), self.inv_freq)
+
+        bias = self.learned_bias.clamp(-2. * math.pi, 0.)
+
+        return freqs, bias
+
+@autocast('cuda', enabled = False)
+def apply_polar_pos_emb(t, freqs):
+    rot_dim, seq_len, orig_dtype = freqs.shape[-1], t.shape[-2], t.dtype
+    freqs = freqs[:, -seq_len:]
+
+    t = t.float()
+
+    t = F.softplus(t)
+    out = cat((t * freqs.cos(), t * freqs.sin()), dim = -1)
+
+    return out.type(orig_dtype)
+
 # norms
 
 class Scale(Module):
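For orientation, here is a minimal standalone sketch of what the added PolarEmbedding / apply_polar_pos_emb pair computes: per-dimension phases from positions, softplus magnitudes from the features, cos/sin components concatenated on the last dimension, and a learned phase bias reserved for the keys. Names, shapes and constants below are illustrative, not the package's API.

```python
# Standalone sketch of the polar (PoPE) transform shown in the hunk above.
import math
import torch
import torch.nn.functional as F

dim_head = 64
base = 10000

inv_freq = 1. / (base ** (torch.arange(0, dim_head).float() / dim_head))  # (dim_head,)
learned_bias = torch.zeros(dim_head)                                      # key phase offset (zero init here)

pos = torch.arange(128)                                                   # token positions
freqs = torch.einsum('i, j -> i j', pos.float(), inv_freq)                # (seq, dim_head) per-dim phases
bias = learned_bias.clamp(-2. * math.pi, 0.)                              # kept inside one full period

q = torch.randn(1, 8, 128, dim_head)   # (batch, heads, seq, dim_head)
k = torch.randn(1, 8, 128, dim_head)

def apply_polar(t, phases):
    mag = F.softplus(t.float())                                           # magnitudes stay positive
    return torch.cat((mag * phases.cos(), mag * phases.sin()), dim = -1)  # doubles the last dimension

q_polar = apply_polar(q, freqs)          # queries use the raw phases
k_polar = apply_polar(k, freqs + bias)   # keys get the learned phase offset
print(q_polar.shape)                     # torch.Size([1, 8, 128, 128])
```

Note that the transform doubles the last dimension (cos and sin parts are concatenated rather than interleaved), and the diff applies it to queries and keys only.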
@@ -1745,6 +1788,7 @@ class Attention(Module):
         attn_bias = None,
         rotary_pos_emb = None,
         context_rotary_pos_emb = None,
+        polar_pos_emb = None,
         pos = None, # for custom alibi positions
         prev_attn = None,
         mem = None,
@@ -1896,6 +1940,11 @@ class Attention(Module):
             q = cat((q_rest, q), dim = 1)
             k = cat((k_rest, k), dim = 1)
 
+        if exists(polar_pos_emb):
+            freqs, bias = polar_pos_emb
+            q = apply_polar_pos_emb(q, freqs)
+            k = apply_polar_pos_emb(k, freqs + bias)
+
         input_mask = context_mask
 
         if not exists(input_mask) and not has_context:
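A quick self-contained check (not taken from the package) of the identity the added q/k transform relies on: with positive magnitudes a, b and phases theta, phi, the dot product of the transformed vectors equals sum_i a_i * b_i * cos(theta_i - phi_i), so each attention logit factors into a content term (magnitudes) and a relative-position term (phase difference, shifted for keys by the learned bias).

```python
# Numeric check of: <[a*cos(theta), a*sin(theta)], [b*cos(phi), b*sin(phi)]> == sum(a*b*cos(theta - phi))
import torch

d = 16
a, b = torch.rand(d) + 0.1, torch.rand(d) + 0.1   # positive magnitudes (what softplus would produce)
theta, phi = torch.rand(d), torch.rand(d)         # per-dimension phases

q = torch.cat((a * theta.cos(), a * theta.sin()))
k = torch.cat((b * phi.cos(), b * phi.sin()))

lhs = q @ k
rhs = (a * b * (theta - phi).cos()).sum()
assert torch.allclose(lhs, rhs, atol = 1e-6)
```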
@@ -2174,6 +2223,8 @@ class AttentionLayers(Module):
         rotary_xpos_scale_base = 512,
         rotary_base_rescale_factor = 1.,
         rotate_num_heads = None,
+        polar_pos_emb = False,
+        polar_bias_uniform_init = False,
         weight_tie_layers = False,
         custom_layers: tuple[str, ...] | None = None,
         layers_execute_order: tuple[int, ...] | None = None,
@@ -2250,14 +2301,13 @@ class AttentionLayers(Module):
 
         # LIMe
 
-        hiddens_counter = 0
         self.layer_integrators = ModuleList([])
 
         assert not (qkv_receive_diff_residuals and not (hyper_conn_produce_diff_views or integrate_layers))
 
         # positions related
 
-        self.disable_abs_pos_emb = default(disable_abs_pos_emb, (rel_pos_bias or rotary_pos_emb))
+        self.disable_abs_pos_emb = default(disable_abs_pos_emb, (rel_pos_bias or rotary_pos_emb or polar_pos_emb))
 
         rotary_emb_dim = default(rotary_emb_dim, dim_head // 2)
 
@@ -2266,9 +2316,14 @@ class AttentionLayers(Module):
         if verbose and rotary_emb_dim < 32:
             logger.warning('when training language model, rotary embedding dimension should be at least 32')
 
+        assert at_most_one_of(rotary_pos_emb, polar_pos_emb), f'either rotary positional embedding or polar positional embedding can be turned on'
         assert not (rotary_xpos and not causal), 'rotary xpos is not compatible with bidirectional attention'
         self.rotary_pos_emb = RotaryEmbedding(rotary_emb_dim, use_xpos = rotary_xpos, scale_base = rotary_xpos_scale_base, interpolation_factor = rotary_interpolation_factor, base_rescale_factor = rotary_base_rescale_factor) if rotary_pos_emb else None
 
+        # polar positional embedding (PoPE) - https://arxiv.org/abs/2509.10534
+
+        self.polar_pos_emb = PolarEmbedding(dim_head, polar_bias_uniform_init) if polar_pos_emb else None
+
         assert at_most_one_of(alibi_pos_bias, rel_pos_bias, data_dependent_alibi), 'you can only choose one of Alibi positional bias, data dependent Alibi (forgetting transformers), dynamic tanh, or T5 relative positional bias'
         assert rel_pos_num_buckets <= rel_pos_max_distance, 'number of relative position buckets must be less than the relative position max distance'
 
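A small standalone illustration (hypothetical values, with the in-place init wrapped in torch.no_grad()) of what the new polar_bias_uniform_init flag controls: the learned key-phase offset starts either at zero or uniformly spread over [-2*pi, 0], and the forward pass clamps it back into that range either way.

```python
# Sketch of the key-phase bias init/clamp behaviour toggled by polar_bias_uniform_init.
import math
import torch
from torch import nn

dim_head = 64

learned_bias = nn.Parameter(torch.zeros(dim_head))   # default: zero phase offset

bias_uniform_init = True
if bias_uniform_init:
    with torch.no_grad():
        learned_bias.uniform_(-2. * math.pi, 0.)     # spread over one full period

bias = learned_bias.clamp(-2. * math.pi, 0.)         # forward keeps it inside [-2*pi, 0]
print(bias.min().item(), bias.max().item())
```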
@@ -2626,6 +2681,7 @@ class AttentionLayers(Module):
         cache_age = 1,
         return_hiddens = False,
         rotary_pos_emb = None,
+        polar_pos_emb = None,
         pos = None,
         context_pos = None,
         attn_bias = None,
@@ -2721,6 +2777,15 @@ class AttentionLayers(Module):
                 context_rotary_pos_emb = context_rotary_pos_emb
             )
 
+        # polar positions
+
+        if exists(self.polar_pos_emb):
+            if not exists(polar_pos_emb):
+                if not exists(pos):
+                    pos = arange(x.shape[1] + seq_pos_offset, device = x.device)
+
+                polar_pos_emb = self.polar_pos_emb(pos)
+
         # assume cached key / values
 
         prev_cache_length = 0
@@ -2910,7 +2975,7 @@ class AttentionLayers(Module):
             # forward depending on layer type
 
             if layer_type == 'a':
-                out, inter = block(x, mask = mask, context_mask = self_attn_kv_mask, attn_mask = attn_mask, rel_pos = self.rel_pos, pos = pos, rotary_pos_emb = rotary_pos_emb, additional_key_values = next(iter_self_attn_kv, None), additional_key_value_mask = additional_kv_mask, prev_attn = prev_attn, cache = next(iter_attn_cache, None), mem = layer_mem, mem_mask = layer_mem_mask, attn_bias = attn_bias, kv_input_residual = next(self_attn_kv_residuals_iter, None), value_residual = maybe_self_attn_value_residual, return_intermediates = True)
+                out, inter = block(x, mask = mask, context_mask = self_attn_kv_mask, attn_mask = attn_mask, rel_pos = self.rel_pos, pos = pos, rotary_pos_emb = rotary_pos_emb, polar_pos_emb = polar_pos_emb, additional_key_values = next(iter_self_attn_kv, None), additional_key_value_mask = additional_kv_mask, prev_attn = prev_attn, cache = next(iter_attn_cache, None), mem = layer_mem, mem_mask = layer_mem_mask, attn_bias = attn_bias, kv_input_residual = next(self_attn_kv_residuals_iter, None), value_residual = maybe_self_attn_value_residual, return_intermediates = True)
             elif layer_type == 'c':
                 out, inter = block(x, context = context, mask = mask, context_mask = context_mask, prev_attn = prev_cross_attn, cache = next(iter_attn_cache, None), kv_input_residual = next(cross_attn_kv_residuals_iter, None), value_residual = maybe_cross_attn_value_residual, **cross_attn_rotary_pos_emb, return_intermediates = True)
             elif layer_type == 'f':
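Putting the changes together, a hedged usage sketch: it assumes the new polar_pos_emb and polar_bias_uniform_init kwargs added to AttentionLayers.__init__ are forwarded through Decoder and TransformerWrapper the same way the existing rotary options are; treat it as illustrative rather than documented API.

```python
# Hypothetical way to turn on polar (PoPE) positions in this release.
import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        polar_pos_emb = True   # PoPE in place of rotary (the diff asserts the two are mutually exclusive)
        # polar_bias_uniform_init is the other new flag (key-phase bias init over [-2*pi, 0])
    )
)

x = torch.randint(0, 20000, (1, 1024))
logits = model(x)   # (1, 1024, 20000)
```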
{x_transformers-2.11.24.dist-info → x_transformers-2.12.1.dist-info}/METADATA CHANGED

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: x-transformers
-Version: 2.11.24
+Version: 2.12.1
 Summary: X-Transformers
 Project-URL: Homepage, https://pypi.org/project/x-transformers/
 Project-URL: Repository, https://github.com/lucidrains/x-transformers
@@ -2630,4 +2630,16 @@ ids_out, num_out, is_number_mask = model.generate(start_ids, start_nums, 17)
 }
 ```
 
+```bibtex
+@misc{gopalakrishnan2025decouplingwhatwherepolar,
+    title   = {Decoupling the "What" and "Where" With Polar Coordinate Positional Embeddings},
+    author  = {Anand Gopalakrishnan and Robert Csordás and Jürgen Schmidhuber and Michael C. Mozer},
+    year    = {2025},
+    eprint  = {2509.10534},
+    archivePrefix = {arXiv},
+    primaryClass  = {cs.LG},
+    url     = {https://arxiv.org/abs/2509.10534},
+}
+```
+
 *solve intelligence... then use that to solve everything else.* - Demis Hassabis
{x_transformers-2.11.24.dist-info → x_transformers-2.12.1.dist-info}/RECORD CHANGED

@@ -11,10 +11,10 @@ x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg
 x_transformers/neo_mlp.py,sha256=XCNnnop9WLarcxap1kGuYc1x8GHvwkZiDRnXOxSl3Po,3452
 x_transformers/nonautoregressive_wrapper.py,sha256=hMQqNimGtchNIe13cR5LZule1V7I1qM5LmY8VQfVdnA,11698
 x_transformers/up_wrapper.py,sha256=YC2LN14_7Xx9Wtiek2rtEJ_qHqdfSmKlh3d7Cgxwd80,7073
-x_transformers/x_transformers.py,sha256=
+x_transformers/x_transformers.py,sha256=hW_adsM99N3MGIvsn7zv5w2KRZ3pXgiMsTK333MEIko,132038
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=AwwYUm8yDAtKQyKJDIhYMsiLTJ_skh3scUFMjp5sda8,8597
-x_transformers-2.
-x_transformers-2.
-x_transformers-2.
-x_transformers-2.
+x_transformers-2.12.1.dist-info/METADATA,sha256=mlqg7VnbIjDvh_Zp7HXHJu1HLzUWB2SB_5qCPlalGhc,97174
+x_transformers-2.12.1.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
+x_transformers-2.12.1.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-2.12.1.dist-info/RECORD,,
{x_transformers-2.11.24.dist-info → x_transformers-2.12.1.dist-info}/WHEEL: file without changes

{x_transformers-2.11.24.dist-info → x_transformers-2.12.1.dist-info}/licenses/LICENSE: file without changes