x-transformers 2.11.24.tar.gz → 2.12.1.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of x-transformers might be problematic.

Files changed (68)
  1. {x_transformers-2.11.24 → x_transformers-2.12.1}/PKG-INFO +13 -1
  2. {x_transformers-2.11.24 → x_transformers-2.12.1}/README.md +12 -0
  3. {x_transformers-2.11.24 → x_transformers-2.12.1}/pyproject.toml +1 -1
  4. {x_transformers-2.11.24 → x_transformers-2.12.1}/tests/test_x_transformers.py +18 -0
  5. {x_transformers-2.11.24 → x_transformers-2.12.1}/train_length_extrapolate.py +55 -25
  6. {x_transformers-2.11.24 → x_transformers-2.12.1}/x_transformers/x_transformers.py +68 -3
  7. {x_transformers-2.11.24 → x_transformers-2.12.1}/.github/FUNDING.yml +0 -0
  8. {x_transformers-2.11.24 → x_transformers-2.12.1}/.github/workflows/python-publish.yml +0 -0
  9. {x_transformers-2.11.24 → x_transformers-2.12.1}/.github/workflows/python-test.yaml +0 -0
  10. {x_transformers-2.11.24 → x_transformers-2.12.1}/.gitignore +0 -0
  11. {x_transformers-2.11.24 → x_transformers-2.12.1}/LICENSE +0 -0
  12. {x_transformers-2.11.24 → x_transformers-2.12.1}/data/README.md +0 -0
  13. {x_transformers-2.11.24 → x_transformers-2.12.1}/data/enwik8.gz +0 -0
  14. {x_transformers-2.11.24 → x_transformers-2.12.1}/images/all-attention.png +0 -0
  15. {x_transformers-2.11.24 → x_transformers-2.12.1}/images/attention-on-attention.png +0 -0
  16. {x_transformers-2.11.24 → x_transformers-2.12.1}/images/cosine-sim-attention.png +0 -0
  17. {x_transformers-2.11.24 → x_transformers-2.12.1}/images/deepnorm.png +0 -0
  18. {x_transformers-2.11.24 → x_transformers-2.12.1}/images/dynamic-pos-bias-linear.png +0 -0
  19. {x_transformers-2.11.24 → x_transformers-2.12.1}/images/dynamic-pos-bias-log.png +0 -0
  20. {x_transformers-2.11.24 → x_transformers-2.12.1}/images/dynamic-pos-bias-sinusoidal.png +0 -0
  21. {x_transformers-2.11.24 → x_transformers-2.12.1}/images/dynamic-pos-bias.png +0 -0
  22. {x_transformers-2.11.24 → x_transformers-2.12.1}/images/enhanced-recurrence.png +0 -0
  23. {x_transformers-2.11.24 → x_transformers-2.12.1}/images/fcm.png +0 -0
  24. {x_transformers-2.11.24 → x_transformers-2.12.1}/images/ffglu.png +0 -0
  25. {x_transformers-2.11.24 → x_transformers-2.12.1}/images/flash-attention.png +0 -0
  26. {x_transformers-2.11.24 → x_transformers-2.12.1}/images/gate_values.png +0 -0
  27. {x_transformers-2.11.24 → x_transformers-2.12.1}/images/gating.png +0 -0
  28. {x_transformers-2.11.24 → x_transformers-2.12.1}/images/length-extrapolation-scale.png +0 -0
  29. {x_transformers-2.11.24 → x_transformers-2.12.1}/images/macaron-1.png +0 -0
  30. {x_transformers-2.11.24 → x_transformers-2.12.1}/images/macaron-2.png +0 -0
  31. {x_transformers-2.11.24 → x_transformers-2.12.1}/images/memory-transformer.png +0 -0
  32. {x_transformers-2.11.24 → x_transformers-2.12.1}/images/normformer.png +0 -0
  33. {x_transformers-2.11.24 → x_transformers-2.12.1}/images/pia.png +0 -0
  34. {x_transformers-2.11.24 → x_transformers-2.12.1}/images/qknorm-analysis.png +0 -0
  35. {x_transformers-2.11.24 → x_transformers-2.12.1}/images/resi_dual.png +0 -0
  36. {x_transformers-2.11.24 → x_transformers-2.12.1}/images/residual_attn.png +0 -0
  37. {x_transformers-2.11.24 → x_transformers-2.12.1}/images/rezero.png +0 -0
  38. {x_transformers-2.11.24 → x_transformers-2.12.1}/images/rotary.png +0 -0
  39. {x_transformers-2.11.24 → x_transformers-2.12.1}/images/sandwich-2.png +0 -0
  40. {x_transformers-2.11.24 → x_transformers-2.12.1}/images/sandwich.png +0 -0
  41. {x_transformers-2.11.24 → x_transformers-2.12.1}/images/sandwich_norm.png +0 -0
  42. {x_transformers-2.11.24 → x_transformers-2.12.1}/images/scalenorm.png +0 -0
  43. {x_transformers-2.11.24 → x_transformers-2.12.1}/images/talking-heads.png +0 -0
  44. {x_transformers-2.11.24 → x_transformers-2.12.1}/images/topk-attention.png +0 -0
  45. {x_transformers-2.11.24 → x_transformers-2.12.1}/images/xval.png +0 -0
  46. {x_transformers-2.11.24 → x_transformers-2.12.1}/train_belief_state.py +0 -0
  47. {x_transformers-2.11.24 → x_transformers-2.12.1}/train_copy.py +0 -0
  48. {x_transformers-2.11.24 → x_transformers-2.12.1}/train_entropy_tokenizer.py +0 -0
  49. {x_transformers-2.11.24 → x_transformers-2.12.1}/train_enwik8.py +0 -0
  50. {x_transformers-2.11.24 → x_transformers-2.12.1}/train_free.py +0 -0
  51. {x_transformers-2.11.24 → x_transformers-2.12.1}/train_gpt_vae.py +0 -0
  52. {x_transformers-2.11.24 → x_transformers-2.12.1}/train_parity.py +0 -0
  53. {x_transformers-2.11.24 → x_transformers-2.12.1}/train_with_muon.py +0 -0
  54. {x_transformers-2.11.24 → x_transformers-2.12.1}/x_transformers/__init__.py +0 -0
  55. {x_transformers-2.11.24 → x_transformers-2.12.1}/x_transformers/attend.py +0 -0
  56. {x_transformers-2.11.24 → x_transformers-2.12.1}/x_transformers/autoregressive_wrapper.py +0 -0
  57. {x_transformers-2.11.24 → x_transformers-2.12.1}/x_transformers/belief_state_wrapper.py +0 -0
  58. {x_transformers-2.11.24 → x_transformers-2.12.1}/x_transformers/continuous.py +0 -0
  59. {x_transformers-2.11.24 → x_transformers-2.12.1}/x_transformers/dpo.py +0 -0
  60. {x_transformers-2.11.24 → x_transformers-2.12.1}/x_transformers/entropy_based_tokenizer.py +0 -0
  61. {x_transformers-2.11.24 → x_transformers-2.12.1}/x_transformers/free_transformer.py +0 -0
  62. {x_transformers-2.11.24 → x_transformers-2.12.1}/x_transformers/gpt_vae.py +0 -0
  63. {x_transformers-2.11.24 → x_transformers-2.12.1}/x_transformers/multi_input.py +0 -0
  64. {x_transformers-2.11.24 → x_transformers-2.12.1}/x_transformers/neo_mlp.py +0 -0
  65. {x_transformers-2.11.24 → x_transformers-2.12.1}/x_transformers/nonautoregressive_wrapper.py +0 -0
  66. {x_transformers-2.11.24 → x_transformers-2.12.1}/x_transformers/up_wrapper.py +0 -0
  67. {x_transformers-2.11.24 → x_transformers-2.12.1}/x_transformers/xl_autoregressive_wrapper.py +0 -0
  68. {x_transformers-2.11.24 → x_transformers-2.12.1}/x_transformers/xval.py +0 -0
--- x_transformers-2.11.24/PKG-INFO
+++ x_transformers-2.12.1/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: x-transformers
-Version: 2.11.24
+Version: 2.12.1
 Summary: X-Transformers
 Project-URL: Homepage, https://pypi.org/project/x-transformers/
 Project-URL: Repository, https://github.com/lucidrains/x-transformers
@@ -2630,4 +2630,16 @@ ids_out, num_out, is_number_mask = model.generate(start_ids, start_nums, 17)
 }
 ```
 
+```bibtex
+@misc{gopalakrishnan2025decouplingwhatwherepolar,
+    title = {Decoupling the "What" and "Where" With Polar Coordinate Positional Embeddings},
+    author = {Anand Gopalakrishnan and Robert Csordás and Jürgen Schmidhuber and Michael C. Mozer},
+    year = {2025},
+    eprint = {2509.10534},
+    archivePrefix = {arXiv},
+    primaryClass = {cs.LG},
+    url = {https://arxiv.org/abs/2509.10534},
+}
+```
+
 *solve intelligence... then use that to solve everything else.* - Demis Hassabis
--- x_transformers-2.11.24/README.md
+++ x_transformers-2.12.1/README.md
@@ -2581,4 +2581,16 @@ ids_out, num_out, is_number_mask = model.generate(start_ids, start_nums, 17)
 }
 ```
 
+```bibtex
+@misc{gopalakrishnan2025decouplingwhatwherepolar,
+    title = {Decoupling the "What" and "Where" With Polar Coordinate Positional Embeddings},
+    author = {Anand Gopalakrishnan and Robert Csordás and Jürgen Schmidhuber and Michael C. Mozer},
+    year = {2025},
+    eprint = {2509.10534},
+    archivePrefix = {arXiv},
+    primaryClass = {cs.LG},
+    url = {https://arxiv.org/abs/2509.10534},
+}
+```
+
 *solve intelligence... then use that to solve everything else.* - Demis Hassabis
--- x_transformers-2.11.24/pyproject.toml
+++ x_transformers-2.12.1/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "x-transformers"
-version = "2.11.24"
+version = "2.12.1"
 description = "X-Transformers"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
--- x_transformers-2.11.24/tests/test_x_transformers.py
+++ x_transformers-2.12.1/tests/test_x_transformers.py
@@ -1508,3 +1508,21 @@ def test_derf():
     x = torch.randint(0, 256, (1, 10))
 
     logits = model(x)
+
+def test_pope():
+    from x_transformers import TransformerWrapper, Decoder
+
+    model = TransformerWrapper(
+        num_tokens = 256,
+        max_seq_len = 1024,
+        attn_layers = Decoder(
+            dim = 512,
+            depth = 6,
+            heads = 8,
+            polar_pos_emb = True,
+        )
+    )
+
+    x = torch.randint(0, 256, (1, 10))
+
+    logits = model(x)
--- x_transformers-2.11.24/train_length_extrapolate.py
+++ x_transformers-2.12.1/train_length_extrapolate.py
@@ -1,3 +1,11 @@
+# /// script
+# dependencies = [
+#   "accelerate",
+#   "tqdm",
+#   "x-transformers>=2.12.0",
+# ]
+# ///
+
 from x_transformers import TransformerWrapper, Decoder
 from x_transformers.autoregressive_wrapper import AutoregressiveWrapper
 
@@ -10,6 +18,8 @@ import torch.optim as optim
 from torch.nn import functional as F
 from torch.utils.data import DataLoader, Dataset
 
+from accelerate import Accelerator
+
 # constants
 
 NUM_BATCHES = int(1e5)
@@ -20,7 +30,7 @@ GENERATE_EVERY = 500
 GENERATE_LENGTH = 256
 SEQ_LEN = 256
 
-VALIDATE_EVERY = 100
+VALIDATE_EVERY = 250
 VALIDATE_SEQ_LENS = (256, 512, 1024, 2048, 4096)
 
 # helpers
@@ -36,6 +46,10 @@ def decode_token(token):
 def decode_tokens(tokens):
     return ''.join(list(map(decode_token, tokens)))
 
+# accelerator
+
+accelerator = Accelerator()
+
 # instantiate GPT-like decoder model
 
 model = TransformerWrapper(
@@ -46,12 +60,13 @@ model = TransformerWrapper(
         dim = 512,
         depth = 6,
         heads = 8,
-        dynamic_pos_bias = True,
+        polar_pos_emb = True,
+        rotary_pos_emb = False,
+        dynamic_pos_bias = False
     )
 )
 
 model = AutoregressiveWrapper(model)
-model.cuda()
 
 # prepare enwik8 data
 
@@ -69,69 +84,84 @@ class TextSamplerDataset(Dataset):
     def __getitem__(self, index):
         rand_start = torch.randint(0, self.data.size(0) - self.seq_len - 1, (1,))
         full_seq = self.data[rand_start: rand_start + self.seq_len + 1].long()
-        return full_seq.cuda()
+        return full_seq
 
     def __len__(self):
         return self.data.size(0) // self.seq_len
 
 train_dataset = TextSamplerDataset(data_train, SEQ_LEN)
-train_loader = cycle(DataLoader(train_dataset, batch_size = BATCH_SIZE, drop_last = True))
+train_loader = DataLoader(train_dataset, batch_size = BATCH_SIZE, drop_last = True)
 
 val_dataset_generate = TextSamplerDataset(data_val, SEQ_LEN)
 
+# optimizer
+
+optim = torch.optim.Adam(model.parameters(), lr = LEARNING_RATE)
+
+# prepare
+
+model, optim, train_loader = accelerator.prepare(model, optim, train_loader)
+
+train_loader = cycle(train_loader)
+
 # validation loaders with different sequence lengths
 
 val_loaders = dict()
 
 for valid_seq_len in VALIDATE_SEQ_LENS:
     val_dataset = TextSamplerDataset(data_val, valid_seq_len)
-    val_loader = cycle(DataLoader(val_dataset, batch_size = BATCH_SIZE, drop_last = True))
+    val_loader = DataLoader(val_dataset, batch_size = BATCH_SIZE, drop_last = True)
+    val_loader = cycle(val_loader)
 
     val_loaders[valid_seq_len] = val_loader
 
-# optimizer
-
-optim = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
-
 # training
 
 for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10., desc='training'):
     model.train()
 
-    for __ in range(GRADIENT_ACCUMULATE_EVERY):
-        loss = model(next(train_loader))
-        (loss / GRADIENT_ACCUMULATE_EVERY).backward()
+    for _ in range(GRADIENT_ACCUMULATE_EVERY):
+        data = next(train_loader)
+        loss = model(data)
+        accelerator.backward(loss / GRADIENT_ACCUMULATE_EVERY)
+
+    if accelerator.sync_gradients:
+        accelerator.clip_grad_norm_(model.parameters(), 0.5)
 
-    print(f'training loss: {loss.item()}')
+    optim.step()
+    optim.zero_grad()
 
-    torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
-    optim.step()
-    optim.zero_grad()
+    if i % 10 == 0:
+        accelerator.print(f'training loss: {loss.item()}')
 
     if i % VALIDATE_EVERY == 0:
-        print(f'validation losses:\n')
+        accelerator.print(f'validation losses:\n')
 
         model.eval()
-        with torch.no_grad():
+        with torch.inference_mode():
             for valid_seq_len in VALIDATE_SEQ_LENS:
                 val_loader = val_loaders[valid_seq_len]
 
-                loss = model(next(val_loader))
-                print(f'[{valid_seq_len}]:\t {loss.item()}')
+                val_data = next(val_loader).to(accelerator.device)
+                loss = model(val_data)
+                accelerator.print(f'[{valid_seq_len}]:\t {loss.item()}')
 
-        print('\n')
+        accelerator.print('\n')
 
     if i % GENERATE_EVERY == 0:
         model.eval()
+        unwrapped_model = accelerator.unwrap_model(model)
+
         inp = random.choice(val_dataset_generate)[:-1]
+        inp = inp.to(accelerator.device)
         prime = decode_tokens(inp)
-        print(f'%s \n\n %s', (prime, '*' * 100))
+        accelerator.print(f'{prime} \n\n {"*" * 100}')
 
-        sample = model.generate(
+        sample = unwrapped_model.generate(
            prompts = inp,
            seq_len = GENERATE_LENGTH,
            cache_kv = True
        )
 
         output_str = decode_tokens(sample)
-        print(f'{output_str}\n\n')
+        accelerator.print(f'{output_str}\n\n')
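The updated training script above swaps the raw CUDA calls for Hugging Face Accelerate. As a quick orientation, here is a minimal sketch of that same pattern in isolation; the tiny linear model and random dataset are illustrative placeholders, not part of the package:

import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader, TensorDataset
from accelerate import Accelerator

accelerator = Accelerator()

# stand-ins for the wrapped transformer and the enwik8 loader in the real script
model = torch.nn.Linear(512, 512)
dataset = TensorDataset(torch.randn(64, 512), torch.randn(64, 512))
loader = DataLoader(dataset, batch_size = 4)
optim = torch.optim.Adam(model.parameters(), lr = 1e-4)

# prepare() moves model / optimizer / dataloader to the right device and handles multi-GPU
model, optim, loader = accelerator.prepare(model, optim, loader)

for inp, target in loader:
    loss = F.mse_loss(model(inp), target)
    accelerator.backward(loss)                         # replaces loss.backward()
    accelerator.clip_grad_norm_(model.parameters(), 0.5)
    optim.step()
    optim.zero_grad()

accelerator.print('done')                              # prints only on the main process

Under these assumptions, the script should run as before with plain python, with accelerate launch train_length_extrapolate.py for multi-GPU, or via a tool that understands the new inline script metadata block (e.g. uv run).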
--- x_transformers-2.11.24/x_transformers/x_transformers.py
+++ x_transformers-2.12.1/x_transformers/x_transformers.py
@@ -779,6 +779,49 @@ def apply_rotary_pos_emb(t, freqs, scale = 1):
 
     return out.type(orig_dtype)
 
+class PolarEmbedding(Module):
+    """ https://arxiv.org/abs/2509.10534 """
+
+    def __init__(
+        self,
+        dim,
+        bias_uniform_init = False,
+        base = 10000,
+    ):
+        super().__init__()
+        inv_freq = 1. / (base ** (arange(0, dim).float() / dim))
+        self.register_buffer('inv_freq', inv_freq)
+
+        self.learned_bias = nn.Parameter(torch.zeros(dim))
+
+        if bias_uniform_init:
+            self.learned_bias.uniform_(-2. * math.pi, 0.)
+
+    @autocast('cuda', enabled = False)
+    def forward(self, t, offset = 0):
+        max_pos = t.max() + 1
+
+        if t.ndim == 1:
+            t = rearrange(t, 'n -> 1 n')
+
+        freqs = torch.einsum('b i , j -> b i j', t.type_as(self.inv_freq), self.inv_freq)
+
+        bias = self.learned_bias.clamp(-2. * math.pi, 0.)
+
+        return freqs, bias
+
+@autocast('cuda', enabled = False)
+def apply_polar_pos_emb(t, freqs):
+    rot_dim, seq_len, orig_dtype = freqs.shape[-1], t.shape[-2], t.dtype
+    freqs = freqs[:, -seq_len:]
+
+    t = t.float()
+
+    t = F.softplus(t)
+    out = cat((t * freqs.cos(), t * freqs.sin()), dim = -1)
+
+    return out.type(orig_dtype)
+
 # norms
 
 class Scale(Module):
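To make the hunk above concrete, here is a hedged sketch of how PolarEmbedding and apply_polar_pos_emb fit together on toy tensors, mirroring the query/key treatment applied in Attention further down; it assumes the two names are importable from x_transformers.x_transformers as module-level definitions, which is what this diff adds:

import torch
from x_transformers.x_transformers import PolarEmbedding, apply_polar_pos_emb

dim_head = 64
pope = PolarEmbedding(dim_head)

pos = torch.arange(128)                # absolute positions
freqs, bias = pope(pos)                # per-position angles, plus a learned per-dimension key phase offset

q = torch.randn(1, 8, 128, dim_head)   # (batch, heads, seq, dim_head)
k = torch.randn(1, 8, 128, dim_head)

# softplus turns each feature into a non-negative magnitude, and cos/sin write it as a
# 2-D vector at the position-dependent angle; keys get the extra learned bias, so
# attention scores end up depending on the relative angle between positions
q = apply_polar_pos_emb(q, freqs)          # -> (1, 8, 128, dim_head * 2)
k = apply_polar_pos_emb(k, freqs + bias)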
@@ -1745,6 +1788,7 @@ class Attention(Module):
         attn_bias = None,
         rotary_pos_emb = None,
         context_rotary_pos_emb = None,
+        polar_pos_emb = None,
         pos = None, # for custom alibi positions
         prev_attn = None,
         mem = None,
@@ -1896,6 +1940,11 @@
                 q = cat((q_rest, q), dim = 1)
                 k = cat((k_rest, k), dim = 1)
 
+        if exists(polar_pos_emb):
+            freqs, bias = polar_pos_emb
+            q = apply_polar_pos_emb(q, freqs)
+            k = apply_polar_pos_emb(k, freqs + bias)
+
         input_mask = context_mask
 
         if not exists(input_mask) and not has_context:
@@ -2174,6 +2223,8 @@ class AttentionLayers(Module):
         rotary_xpos_scale_base = 512,
         rotary_base_rescale_factor = 1.,
         rotate_num_heads = None,
+        polar_pos_emb = False,
+        polar_bias_uniform_init = False,
         weight_tie_layers = False,
         custom_layers: tuple[str, ...] | None = None,
         layers_execute_order: tuple[int, ...] | None = None,
@@ -2250,14 +2301,13 @@
 
         # LIMe
 
-        hiddens_counter = 0
         self.layer_integrators = ModuleList([])
 
         assert not (qkv_receive_diff_residuals and not (hyper_conn_produce_diff_views or integrate_layers))
 
         # positions related
 
-        self.disable_abs_pos_emb = default(disable_abs_pos_emb, (rel_pos_bias or rotary_pos_emb))
+        self.disable_abs_pos_emb = default(disable_abs_pos_emb, (rel_pos_bias or rotary_pos_emb or polar_pos_emb))
 
         rotary_emb_dim = default(rotary_emb_dim, dim_head // 2)
 
@@ -2266,9 +2316,14 @@
         if verbose and rotary_emb_dim < 32:
             logger.warning('when training language model, rotary embedding dimension should be at least 32')
 
+        assert at_most_one_of(rotary_pos_emb, polar_pos_emb), f'either rotary positional embedding or polar positional embedding can be turned on'
         assert not (rotary_xpos and not causal), 'rotary xpos is not compatible with bidirectional attention'
         self.rotary_pos_emb = RotaryEmbedding(rotary_emb_dim, use_xpos = rotary_xpos, scale_base = rotary_xpos_scale_base, interpolation_factor = rotary_interpolation_factor, base_rescale_factor = rotary_base_rescale_factor) if rotary_pos_emb else None
 
+        # polar positional embedding (PoPE) - https://arxiv.org/abs/2509.10534
+
+        self.polar_pos_emb = PolarEmbedding(dim_head, polar_bias_uniform_init) if polar_pos_emb else None
+
         assert at_most_one_of(alibi_pos_bias, rel_pos_bias, data_dependent_alibi), 'you can only choose one of Alibi positional bias, data dependent Alibi (forgetting transformers), dynamic tanh, or T5 relative positional bias'
         assert rel_pos_num_buckets <= rel_pos_max_distance, 'number of relative position buckets must be less than the relative position max distance'
 
@@ -2626,6 +2681,7 @@
         cache_age = 1,
         return_hiddens = False,
         rotary_pos_emb = None,
+        polar_pos_emb = None,
         pos = None,
         context_pos = None,
         attn_bias = None,
@@ -2721,6 +2777,15 @@
                 context_rotary_pos_emb = context_rotary_pos_emb
             )
 
+        # polar positions
+
+        if exists(self.polar_pos_emb):
+            if not exists(polar_pos_emb):
+                if not exists(pos):
+                    pos = arange(x.shape[1] + seq_pos_offset, device = x.device)
+
+                polar_pos_emb = self.polar_pos_emb(pos)
+
         # assume cached key / values
 
         prev_cache_length = 0
@@ -2910,7 +2975,7 @@
             # forward depending on layer type
 
             if layer_type == 'a':
-                out, inter = block(x, mask = mask, context_mask = self_attn_kv_mask, attn_mask = attn_mask, rel_pos = self.rel_pos, pos = pos, rotary_pos_emb = rotary_pos_emb, additional_key_values = next(iter_self_attn_kv, None), additional_key_value_mask = additional_kv_mask, prev_attn = prev_attn, cache = next(iter_attn_cache, None), mem = layer_mem, mem_mask = layer_mem_mask, attn_bias = attn_bias, kv_input_residual = next(self_attn_kv_residuals_iter, None), value_residual = maybe_self_attn_value_residual, return_intermediates = True)
+                out, inter = block(x, mask = mask, context_mask = self_attn_kv_mask, attn_mask = attn_mask, rel_pos = self.rel_pos, pos = pos, rotary_pos_emb = rotary_pos_emb, polar_pos_emb = polar_pos_emb, additional_key_values = next(iter_self_attn_kv, None), additional_key_value_mask = additional_kv_mask, prev_attn = prev_attn, cache = next(iter_attn_cache, None), mem = layer_mem, mem_mask = layer_mem_mask, attn_bias = attn_bias, kv_input_residual = next(self_attn_kv_residuals_iter, None), value_residual = maybe_self_attn_value_residual, return_intermediates = True)
             elif layer_type == 'c':
                 out, inter = block(x, context = context, mask = mask, context_mask = context_mask, prev_attn = prev_cross_attn, cache = next(iter_attn_cache, None), kv_input_residual = next(cross_attn_kv_residuals_iter, None), value_residual = maybe_cross_attn_value_residual, **cross_attn_rotary_pos_emb, return_intermediates = True)
             elif layer_type == 'f':
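Putting the pieces together, enabling the new polar positional embedding from user code looks like the test_pope test added in this release. A hedged usage sketch follows; note the new assert only allows one of rotary or polar positions at a time, and polar_bias_uniform_init is also exposed on AttentionLayers:

import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        polar_pos_emb = True    # new in 2.12.x; rotary_pos_emb stays at its default of False
    )
)

x = torch.randint(0, 256, (1, 10))
logits = model(x)               # (1, 10, 256)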