x-transformers 2.11.24__py3-none-any.whl → 2.12.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


@@ -779,6 +779,49 @@ def apply_rotary_pos_emb(t, freqs, scale = 1):
 
     return out.type(orig_dtype)
 
+class PolarEmbedding(Module):
+    """ https://arxiv.org/abs/2509.10534 """
+
+    def __init__(
+        self,
+        dim,
+        bias_uniform_init = False,
+        base = 10000,
+    ):
+        super().__init__()
+        inv_freq = 1. / (base ** (arange(0, dim).float() / dim))
+        self.register_buffer('inv_freq', inv_freq)
+
+        self.learned_bias = nn.Parameter(torch.zeros(dim))
+
+        if bias_uniform_init:
+            self.learned_bias.uniform_(-2. * math.pi, 0.)
+
+    @autocast('cuda', enabled = False)
+    def forward(self, t, offset = 0):
+        max_pos = t.max() + 1
+
+        if t.ndim == 1:
+            t = rearrange(t, 'n -> 1 n')
+
+        freqs = torch.einsum('b i , j -> b i j', t.type_as(self.inv_freq), self.inv_freq)
+
+        bias = self.learned_bias.clamp(-2. * math.pi, 0.)
+
+        return freqs, bias
+
+@autocast('cuda', enabled = False)
+def apply_polar_pos_emb(t, freqs):
+    rot_dim, seq_len, orig_dtype = freqs.shape[-1], t.shape[-2], t.dtype
+    freqs = freqs[:, -seq_len:]
+
+    t = t.float()
+
+    t = F.softplus(t)
+    out = cat((t * freqs.cos(), t * freqs.sin()), dim = -1)
+
+    return out.type(orig_dtype)
+
 # norms
 
 class Scale(Module):
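
For orientation, here is a minimal standalone sketch of how the two new helpers compose. It is not part of the diff; it assumes `PolarEmbedding` and `apply_polar_pos_emb` are importable from `x_transformers.x_transformers` in 2.12.1, exactly as the hunk above defines them, and it mirrors the `freqs + bias` convention for keys used in the `Attention` change further down.

```python
# hedged sketch, not part of the diff: exercises PolarEmbedding / apply_polar_pos_emb
# standalone, assuming both names are importable from x_transformers.x_transformers
import torch
from x_transformers.x_transformers import PolarEmbedding, apply_polar_pos_emb

dim_head = 64
polar = PolarEmbedding(dim_head)

q = torch.randn(2, 8, 128, dim_head)   # (batch, heads, seq, dim_head)
k = torch.randn(2, 8, 128, dim_head)

pos = torch.arange(128)
freqs, bias = polar(pos)                # freqs: (1, 128, dim_head), bias: (dim_head,)

# queries take the raw angles; keys additionally receive the learned, clamped phase bias
q = apply_polar_pos_emb(q, freqs)       # magnitudes via softplus, last dim doubles
k = apply_polar_pos_emb(k, freqs + bias)

print(q.shape, k.shape)                 # torch.Size([2, 8, 128, 128]) for both
```
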
@@ -1745,6 +1788,7 @@ class Attention(Module):
         attn_bias = None,
         rotary_pos_emb = None,
         context_rotary_pos_emb = None,
+        polar_pos_emb = None,
         pos = None, # for custom alibi positions
         prev_attn = None,
         mem = None,
@@ -1896,6 +1940,11 @@ class Attention(Module):
             q = cat((q_rest, q), dim = 1)
             k = cat((k_rest, k), dim = 1)
 
+        if exists(polar_pos_emb):
+            freqs, bias = polar_pos_emb
+            q = apply_polar_pos_emb(q, freqs)
+            k = apply_polar_pos_emb(k, freqs + bias)
+
         input_mask = context_mask
 
         if not exists(input_mask) and not has_context:
@@ -2174,6 +2223,8 @@ class AttentionLayers(Module):
         rotary_xpos_scale_base = 512,
         rotary_base_rescale_factor = 1.,
         rotate_num_heads = None,
+        polar_pos_emb = False,
+        polar_bias_uniform_init = False,
         weight_tie_layers = False,
         custom_layers: tuple[str, ...] | None = None,
         layers_execute_order: tuple[int, ...] | None = None,
@@ -2250,14 +2301,13 @@ class AttentionLayers(Module):
 
         # LIMe
 
-        hiddens_counter = 0
         self.layer_integrators = ModuleList([])
 
         assert not (qkv_receive_diff_residuals and not (hyper_conn_produce_diff_views or integrate_layers))
 
         # positions related
 
-        self.disable_abs_pos_emb = default(disable_abs_pos_emb, (rel_pos_bias or rotary_pos_emb))
+        self.disable_abs_pos_emb = default(disable_abs_pos_emb, (rel_pos_bias or rotary_pos_emb or polar_pos_emb))
 
         rotary_emb_dim = default(rotary_emb_dim, dim_head // 2)
 
@@ -2266,9 +2316,14 @@ class AttentionLayers(Module):
         if verbose and rotary_emb_dim < 32:
             logger.warning('when training language model, rotary embedding dimension should be at least 32')
 
+        assert at_most_one_of(rotary_pos_emb, polar_pos_emb), f'either rotary positional embedding or polar positional embedding can be turned on'
         assert not (rotary_xpos and not causal), 'rotary xpos is not compatible with bidirectional attention'
         self.rotary_pos_emb = RotaryEmbedding(rotary_emb_dim, use_xpos = rotary_xpos, scale_base = rotary_xpos_scale_base, interpolation_factor = rotary_interpolation_factor, base_rescale_factor = rotary_base_rescale_factor) if rotary_pos_emb else None
 
+        # polar positional embedding (PoPE) - https://arxiv.org/abs/2509.10534
+
+        self.polar_pos_emb = PolarEmbedding(dim_head, polar_bias_uniform_init) if polar_pos_emb else None
+
         assert at_most_one_of(alibi_pos_bias, rel_pos_bias, data_dependent_alibi), 'you can only choose one of Alibi positional bias, data dependent Alibi (forgetting transformers), dynamic tanh, or T5 relative positional bias'
         assert rel_pos_num_buckets <= rel_pos_max_distance, 'number of relative position buckets must be less than the relative position max distance'
 
@@ -2626,6 +2681,7 @@ class AttentionLayers(Module):
         cache_age = 1,
         return_hiddens = False,
         rotary_pos_emb = None,
+        polar_pos_emb = None,
         pos = None,
         context_pos = None,
         attn_bias = None,
@@ -2721,6 +2777,15 @@ class AttentionLayers(Module):
                 context_rotary_pos_emb = context_rotary_pos_emb
             )
 
+        # polar positions
+
+        if exists(self.polar_pos_emb):
+            if not exists(polar_pos_emb):
+                if not exists(pos):
+                    pos = arange(x.shape[1] + seq_pos_offset, device = x.device)
+
+                polar_pos_emb = self.polar_pos_emb(pos)
+
         # assume cached key / values
 
         prev_cache_length = 0
@@ -2910,7 +2975,7 @@ class AttentionLayers(Module):
             # forward depending on layer type
 
             if layer_type == 'a':
-                out, inter = block(x, mask = mask, context_mask = self_attn_kv_mask, attn_mask = attn_mask, rel_pos = self.rel_pos, pos = pos, rotary_pos_emb = rotary_pos_emb, additional_key_values = next(iter_self_attn_kv, None), additional_key_value_mask = additional_kv_mask, prev_attn = prev_attn, cache = next(iter_attn_cache, None), mem = layer_mem, mem_mask = layer_mem_mask, attn_bias = attn_bias, kv_input_residual = next(self_attn_kv_residuals_iter, None), value_residual = maybe_self_attn_value_residual, return_intermediates = True)
+                out, inter = block(x, mask = mask, context_mask = self_attn_kv_mask, attn_mask = attn_mask, rel_pos = self.rel_pos, pos = pos, rotary_pos_emb = rotary_pos_emb, polar_pos_emb = polar_pos_emb, additional_key_values = next(iter_self_attn_kv, None), additional_key_value_mask = additional_kv_mask, prev_attn = prev_attn, cache = next(iter_attn_cache, None), mem = layer_mem, mem_mask = layer_mem_mask, attn_bias = attn_bias, kv_input_residual = next(self_attn_kv_residuals_iter, None), value_residual = maybe_self_attn_value_residual, return_intermediates = True)
             elif layer_type == 'c':
                 out, inter = block(x, context = context, mask = mask, context_mask = context_mask, prev_attn = prev_cross_attn, cache = next(iter_attn_cache, None), kv_input_residual = next(cross_attn_kv_residuals_iter, None), value_residual = maybe_cross_attn_value_residual, **cross_attn_rotary_pos_emb, return_intermediates = True)
             elif layer_type == 'f':
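
Taken together, these hunks thread the new flag from the `AttentionLayers` constructor through to every self-attention block. A minimal usage sketch, not part of the diff, assuming `Decoder` forwards its keyword arguments to `AttentionLayers` as in previous releases:

```python
# hedged usage sketch: enable PoPE via the new constructor flags in 2.12.1
import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        polar_pos_emb = True,               # new flag; mutually exclusive with rotary_pos_emb
        # polar_bias_uniform_init = True,   # optional: uniform init of the learned key bias in [-2*pi, 0]
    )
)

x = torch.randint(0, 20000, (1, 256))
logits = model(x)                           # (1, 256, 20000)
```
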
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: x-transformers
-Version: 2.11.24
+Version: 2.12.1
 Summary: X-Transformers
 Project-URL: Homepage, https://pypi.org/project/x-transformers/
 Project-URL: Repository, https://github.com/lucidrains/x-transformers
@@ -2630,4 +2630,16 @@ ids_out, num_out, is_number_mask = model.generate(start_ids, start_nums, 17)
 }
 ```
 
+```bibtex
+@misc{gopalakrishnan2025decouplingwhatwherepolar,
+    title = {Decoupling the "What" and "Where" With Polar Coordinate Positional Embeddings},
+    author = {Anand Gopalakrishnan and Robert Csordás and Jürgen Schmidhuber and Michael C. Mozer},
+    year = {2025},
+    eprint = {2509.10534},
+    archivePrefix = {arXiv},
+    primaryClass = {cs.LG},
+    url = {https://arxiv.org/abs/2509.10534},
+}
+```
+
 *solve intelligence... then use that to solve everything else.* - Demis Hassabis
@@ -11,10 +11,10 @@ x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg
 x_transformers/neo_mlp.py,sha256=XCNnnop9WLarcxap1kGuYc1x8GHvwkZiDRnXOxSl3Po,3452
 x_transformers/nonautoregressive_wrapper.py,sha256=hMQqNimGtchNIe13cR5LZule1V7I1qM5LmY8VQfVdnA,11698
 x_transformers/up_wrapper.py,sha256=YC2LN14_7Xx9Wtiek2rtEJ_qHqdfSmKlh3d7Cgxwd80,7073
-x_transformers/x_transformers.py,sha256=ESupgE2cCteH4_PMWz9tPolAmT0lqBEwITUnT6ZhR8Y,129954
+x_transformers/x_transformers.py,sha256=hW_adsM99N3MGIvsn7zv5w2KRZ3pXgiMsTK333MEIko,132038
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=AwwYUm8yDAtKQyKJDIhYMsiLTJ_skh3scUFMjp5sda8,8597
-x_transformers-2.11.24.dist-info/METADATA,sha256=i318BRxUVm0HCN6Yr63oRFG8ISc04t6-8wT8N6mfFew,96751
-x_transformers-2.11.24.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
-x_transformers-2.11.24.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-2.11.24.dist-info/RECORD,,
+x_transformers-2.12.1.dist-info/METADATA,sha256=mlqg7VnbIjDvh_Zp7HXHJu1HLzUWB2SB_5qCPlalGhc,97174
+x_transformers-2.12.1.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
+x_transformers-2.12.1.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-2.12.1.dist-info/RECORD,,