x-transformers 2.11.23__tar.gz → 2.12.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of x-transformers might be problematic.
- {x_transformers-2.11.23 → x_transformers-2.12.0}/PKG-INFO +25 -1
- {x_transformers-2.11.23 → x_transformers-2.12.0}/README.md +24 -0
- {x_transformers-2.11.23 → x_transformers-2.12.0}/pyproject.toml +1 -1
- {x_transformers-2.11.23 → x_transformers-2.12.0}/tests/test_x_transformers.py +38 -0
- {x_transformers-2.11.23 → x_transformers-2.12.0}/x_transformers/x_transformers.py +96 -2
- {x_transformers-2.11.23 → x_transformers-2.12.0}/.github/FUNDING.yml +0 -0
- {x_transformers-2.11.23 → x_transformers-2.12.0}/.github/workflows/python-publish.yml +0 -0
- {x_transformers-2.11.23 → x_transformers-2.12.0}/.github/workflows/python-test.yaml +0 -0
- {x_transformers-2.11.23 → x_transformers-2.12.0}/.gitignore +0 -0
- {x_transformers-2.11.23 → x_transformers-2.12.0}/LICENSE +0 -0
- {x_transformers-2.11.23 → x_transformers-2.12.0}/data/README.md +0 -0
- {x_transformers-2.11.23 → x_transformers-2.12.0}/data/enwik8.gz +0 -0
- {x_transformers-2.11.23 → x_transformers-2.12.0}/images/all-attention.png +0 -0
- {x_transformers-2.11.23 → x_transformers-2.12.0}/images/attention-on-attention.png +0 -0
- {x_transformers-2.11.23 → x_transformers-2.12.0}/images/cosine-sim-attention.png +0 -0
- {x_transformers-2.11.23 → x_transformers-2.12.0}/images/deepnorm.png +0 -0
- {x_transformers-2.11.23 → x_transformers-2.12.0}/images/dynamic-pos-bias-linear.png +0 -0
- {x_transformers-2.11.23 → x_transformers-2.12.0}/images/dynamic-pos-bias-log.png +0 -0
- {x_transformers-2.11.23 → x_transformers-2.12.0}/images/dynamic-pos-bias-sinusoidal.png +0 -0
- {x_transformers-2.11.23 → x_transformers-2.12.0}/images/dynamic-pos-bias.png +0 -0
- {x_transformers-2.11.23 → x_transformers-2.12.0}/images/enhanced-recurrence.png +0 -0
- {x_transformers-2.11.23 → x_transformers-2.12.0}/images/fcm.png +0 -0
- {x_transformers-2.11.23 → x_transformers-2.12.0}/images/ffglu.png +0 -0
- {x_transformers-2.11.23 → x_transformers-2.12.0}/images/flash-attention.png +0 -0
- {x_transformers-2.11.23 → x_transformers-2.12.0}/images/gate_values.png +0 -0
- {x_transformers-2.11.23 → x_transformers-2.12.0}/images/gating.png +0 -0
- {x_transformers-2.11.23 → x_transformers-2.12.0}/images/length-extrapolation-scale.png +0 -0
- {x_transformers-2.11.23 → x_transformers-2.12.0}/images/macaron-1.png +0 -0
- {x_transformers-2.11.23 → x_transformers-2.12.0}/images/macaron-2.png +0 -0
- {x_transformers-2.11.23 → x_transformers-2.12.0}/images/memory-transformer.png +0 -0
- {x_transformers-2.11.23 → x_transformers-2.12.0}/images/normformer.png +0 -0
- {x_transformers-2.11.23 → x_transformers-2.12.0}/images/pia.png +0 -0
- {x_transformers-2.11.23 → x_transformers-2.12.0}/images/qknorm-analysis.png +0 -0
- {x_transformers-2.11.23 → x_transformers-2.12.0}/images/resi_dual.png +0 -0
- {x_transformers-2.11.23 → x_transformers-2.12.0}/images/residual_attn.png +0 -0
- {x_transformers-2.11.23 → x_transformers-2.12.0}/images/rezero.png +0 -0
- {x_transformers-2.11.23 → x_transformers-2.12.0}/images/rotary.png +0 -0
- {x_transformers-2.11.23 → x_transformers-2.12.0}/images/sandwich-2.png +0 -0
- {x_transformers-2.11.23 → x_transformers-2.12.0}/images/sandwich.png +0 -0
- {x_transformers-2.11.23 → x_transformers-2.12.0}/images/sandwich_norm.png +0 -0
- {x_transformers-2.11.23 → x_transformers-2.12.0}/images/scalenorm.png +0 -0
- {x_transformers-2.11.23 → x_transformers-2.12.0}/images/talking-heads.png +0 -0
- {x_transformers-2.11.23 → x_transformers-2.12.0}/images/topk-attention.png +0 -0
- {x_transformers-2.11.23 → x_transformers-2.12.0}/images/xval.png +0 -0
- {x_transformers-2.11.23 → x_transformers-2.12.0}/train_belief_state.py +0 -0
- {x_transformers-2.11.23 → x_transformers-2.12.0}/train_copy.py +0 -0
- {x_transformers-2.11.23 → x_transformers-2.12.0}/train_entropy_tokenizer.py +0 -0
- {x_transformers-2.11.23 → x_transformers-2.12.0}/train_enwik8.py +0 -0
- {x_transformers-2.11.23 → x_transformers-2.12.0}/train_free.py +0 -0
- {x_transformers-2.11.23 → x_transformers-2.12.0}/train_gpt_vae.py +0 -0
- {x_transformers-2.11.23 → x_transformers-2.12.0}/train_length_extrapolate.py +0 -0
- {x_transformers-2.11.23 → x_transformers-2.12.0}/train_parity.py +0 -0
- {x_transformers-2.11.23 → x_transformers-2.12.0}/train_with_muon.py +0 -0
- {x_transformers-2.11.23 → x_transformers-2.12.0}/x_transformers/__init__.py +0 -0
- {x_transformers-2.11.23 → x_transformers-2.12.0}/x_transformers/attend.py +0 -0
- {x_transformers-2.11.23 → x_transformers-2.12.0}/x_transformers/autoregressive_wrapper.py +0 -0
- {x_transformers-2.11.23 → x_transformers-2.12.0}/x_transformers/belief_state_wrapper.py +0 -0
- {x_transformers-2.11.23 → x_transformers-2.12.0}/x_transformers/continuous.py +0 -0
- {x_transformers-2.11.23 → x_transformers-2.12.0}/x_transformers/dpo.py +0 -0
- {x_transformers-2.11.23 → x_transformers-2.12.0}/x_transformers/entropy_based_tokenizer.py +0 -0
- {x_transformers-2.11.23 → x_transformers-2.12.0}/x_transformers/free_transformer.py +0 -0
- {x_transformers-2.11.23 → x_transformers-2.12.0}/x_transformers/gpt_vae.py +0 -0
- {x_transformers-2.11.23 → x_transformers-2.12.0}/x_transformers/multi_input.py +0 -0
- {x_transformers-2.11.23 → x_transformers-2.12.0}/x_transformers/neo_mlp.py +0 -0
- {x_transformers-2.11.23 → x_transformers-2.12.0}/x_transformers/nonautoregressive_wrapper.py +0 -0
- {x_transformers-2.11.23 → x_transformers-2.12.0}/x_transformers/up_wrapper.py +0 -0
- {x_transformers-2.11.23 → x_transformers-2.12.0}/x_transformers/xl_autoregressive_wrapper.py +0 -0
- {x_transformers-2.11.23 → x_transformers-2.12.0}/x_transformers/xval.py +0 -0
{x_transformers-2.11.23 → x_transformers-2.12.0}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: x-transformers
-Version: 2.11.23
+Version: 2.12.0
 Summary: X-Transformers
 Project-URL: Homepage, https://pypi.org/project/x-transformers/
 Project-URL: Repository, https://github.com/lucidrains/x-transformers

@@ -2618,4 +2618,28 @@ ids_out, num_out, is_number_mask = model.generate(start_ids, start_nums, 17)
 }
 ```
 
+```bibtex
+@misc{chen2025strongernormalizationfreetransformers,
+    title = {Stronger Normalization-Free Transformers},
+    author = {Mingzhi Chen and Taiming Lu and Jiachen Zhu and Mingjie Sun and Zhuang Liu},
+    year = {2025},
+    eprint = {2512.10938},
+    archivePrefix = {arXiv},
+    primaryClass = {cs.LG},
+    url = {https://arxiv.org/abs/2512.10938},
+}
+```
+
+```bibtex
+@misc{gopalakrishnan2025decouplingwhatwherepolar,
+    title = {Decoupling the "What" and "Where" With Polar Coordinate Positional Embeddings},
+    author = {Anand Gopalakrishnan and Robert Csordás and Jürgen Schmidhuber and Michael C. Mozer},
+    year = {2025},
+    eprint = {2509.10534},
+    archivePrefix = {arXiv},
+    primaryClass = {cs.LG},
+    url = {https://arxiv.org/abs/2509.10534},
+}
+```
+
 *solve intelligence... then use that to solve everything else.* - Demis Hassabis
{x_transformers-2.11.23 → x_transformers-2.12.0}/README.md

@@ -2569,4 +2569,28 @@ ids_out, num_out, is_number_mask = model.generate(start_ids, start_nums, 17)
 }
 ```
 
+```bibtex
+@misc{chen2025strongernormalizationfreetransformers,
+    title = {Stronger Normalization-Free Transformers},
+    author = {Mingzhi Chen and Taiming Lu and Jiachen Zhu and Mingjie Sun and Zhuang Liu},
+    year = {2025},
+    eprint = {2512.10938},
+    archivePrefix = {arXiv},
+    primaryClass = {cs.LG},
+    url = {https://arxiv.org/abs/2512.10938},
+}
+```
+
+```bibtex
+@misc{gopalakrishnan2025decouplingwhatwherepolar,
+    title = {Decoupling the "What" and "Where" With Polar Coordinate Positional Embeddings},
+    author = {Anand Gopalakrishnan and Robert Csordás and Jürgen Schmidhuber and Michael C. Mozer},
+    year = {2025},
+    eprint = {2509.10534},
+    archivePrefix = {arXiv},
+    primaryClass = {cs.LG},
+    url = {https://arxiv.org/abs/2509.10534},
+}
+```
+
 *solve intelligence... then use that to solve everything else.* - Demis Hassabis
{x_transformers-2.11.23 → x_transformers-2.12.0}/tests/test_x_transformers.py

@@ -1488,3 +1488,41 @@ def test_belief_attn(
     x = torch.randint(0, 256, (1, 10))
 
     logits = model(x)
+
+def test_derf():
+    from x_transformers import TransformerWrapper, Decoder
+
+    model = TransformerWrapper(
+        num_tokens = 256,
+        max_seq_len = 1024,
+        attn_layers = Decoder(
+            dim = 512,
+            depth = 6,
+            heads = 8,
+            attn_kv_heads = 4,
+            rotary_pos_emb = True,
+            use_derf = True
+        )
+    )
+
+    x = torch.randint(0, 256, (1, 10))
+
+    logits = model(x)
+
+def test_pope():
+    from x_transformers import TransformerWrapper, Decoder
+
+    model = TransformerWrapper(
+        num_tokens = 256,
+        max_seq_len = 1024,
+        attn_layers = Decoder(
+            dim = 512,
+            depth = 6,
+            heads = 8,
+            polar_pos_emb = True,
+        )
+    )
+
+    x = torch.randint(0, 256, (1, 10))
+
+    logits = model(x)
{x_transformers-2.11.23 → x_transformers-2.12.0}/x_transformers/x_transformers.py

@@ -779,6 +779,49 @@ def apply_rotary_pos_emb(t, freqs, scale = 1):
 
     return out.type(orig_dtype)
 
+class PolarEmbedding(Module):
+    """ https://arxiv.org/abs/2509.10534 """
+
+    def __init__(
+        self,
+        dim,
+        bias_uniform_init = False,
+        base = 10000,
+    ):
+        super().__init__()
+        inv_freq = 1. / (base ** (arange(0, dim).float() / dim))
+        self.register_buffer('inv_freq', inv_freq)
+
+        self.learned_bias = nn.Parameter(torch.zeros(dim))
+
+        if bias_uniform_init:
+            self.learned_bias.uniform_(-2. * math.pi, 0.)
+
+    @autocast('cuda', enabled = False)
+    def forward(self, t, offset = 0):
+        max_pos = t.max() + 1
+
+        if t.ndim == 1:
+            t = rearrange(t, 'n -> 1 n')
+
+        freqs = torch.einsum('b i , j -> b i j', t.type_as(self.inv_freq), self.inv_freq)
+
+        bias = self.learned_bias.clamp(-2. * math.pi, 0.)
+
+        return freqs, bias
+
+@autocast('cuda', enabled = False)
+def apply_polar_pos_emb(t, freqs):
+    rot_dim, seq_len, orig_dtype = freqs.shape[-1], t.shape[-2], t.dtype
+    freqs = freqs[:, -seq_len:]
+
+    t = t.float()
+
+    t = F.softplus(t)
+    out = cat((t * freqs.cos(), t * freqs.sin()), dim = -1)
+
+    return out.type(orig_dtype)
+
 # norms
 
 class Scale(Module):
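For readers skimming the hunk above: a minimal, self-contained sketch (illustrative names only, not the library API) of what `PolarEmbedding` plus `apply_polar_pos_emb` compute: phase angles from inverse frequencies, and softplus-positive magnitudes, concatenated as (magnitude · cos, magnitude · sin).

```python
# Toy sketch of the polar embedding computation added in this release.
import torch
import torch.nn.functional as F

dim = 4                                  # toy per-head dimension
base = 10000
inv_freq = 1. / (base ** (torch.arange(0, dim).float() / dim))

pos = torch.arange(10)                   # absolute positions 0..9
freqs = torch.einsum('i, j -> i j', pos.float(), inv_freq)   # (seq, dim) phase angles

q = torch.randn(10, dim)                 # toy per-position query features
mag = F.softplus(q)                      # magnitudes kept non-negative
q_polar = torch.cat((mag * freqs.cos(), mag * freqs.sin()), dim = -1)

print(q_polar.shape)                     # torch.Size([10, 8])
```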
@@ -941,6 +984,31 @@ class DynamicTanh(Module):
         gamma = self.gamma + self.gamma_offset
         return (x * pre_tanh_scale).tanh() * gamma + self.beta
 
+class Derf(Module):
+    """ https://arxiv.org/abs/2512.10938 """
+    def __init__(
+        self,
+        dim,
+        init_alpha = 0.5,
+        init_bias = 0.,
+        unit_offset = False
+    ):
+        super().__init__()
+        scale_offset = 1. if unit_offset else 0.
+
+        self.alpha = nn.Parameter(tensor(init_alpha) - scale_offset)
+        self.s = nn.Parameter(tensor(init_bias))
+
+        self.gamma = nn.Parameter(torch.ones(dim) - scale_offset)
+        self.beta = nn.Parameter(torch.zeros(dim))
+
+        self.scale_offset = scale_offset
+
+    def forward(self, x):
+        x = x * (self.alpha + self.scale_offset) + self.s
+        activated = torch.erf(x)
+        return activated * (self.gamma + self.scale_offset) + self.beta
+
 # residual and residual gates
 
 class Residual(Module):
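The `Derf` module above is an elementwise, statistics-free stand-in for the usual pre-norm, following arXiv:2512.10938: y = erf(alpha · x + s) · gamma + beta. A toy re-implementation under that reading (hypothetical `ToyDerf`, not the library class):

```python
# Minimal sketch of the Derf transform: no mean/variance statistics, just a learned
# scalar scale/shift into erf, followed by a per-channel affine, like DynamicTanh.
import torch
from torch import nn

class ToyDerf(nn.Module):
    def __init__(self, dim, init_alpha = 0.5):
        super().__init__()
        self.alpha = nn.Parameter(torch.tensor(init_alpha))  # learned input scale
        self.s = nn.Parameter(torch.tensor(0.))              # learned input shift
        self.gamma = nn.Parameter(torch.ones(dim))            # per-channel gain
        self.beta = nn.Parameter(torch.zeros(dim))            # per-channel bias

    def forward(self, x):
        return torch.erf(self.alpha * x + self.s) * self.gamma + self.beta

x = torch.randn(2, 10, 512)
print(ToyDerf(512)(x).shape)   # torch.Size([2, 10, 512])
```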
@@ -1720,6 +1788,7 @@ class Attention(Module):
         attn_bias = None,
         rotary_pos_emb = None,
         context_rotary_pos_emb = None,
+        polar_pos_emb = None,
         pos = None, # for custom alibi positions
         prev_attn = None,
         mem = None,

@@ -1871,6 +1940,11 @@ class Attention(Module):
                 q = cat((q_rest, q), dim = 1)
                 k = cat((k_rest, k), dim = 1)
 
+        if exists(polar_pos_emb):
+            freqs, bias = polar_pos_emb
+            q = apply_polar_pos_emb(q, freqs)
+            k = apply_polar_pos_emb(k, freqs + bias)
+
         input_mask = context_mask
 
         if not exists(input_mask) and not has_context:
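Why applying the polar map to queries and keys (with the learned phase bias on keys only) decouples content from position: a query at phase θ_i and a key at phase θ_j + φ give a dot product of Σ_d a_d b_d cos(θ_i,d − θ_j,d − φ_d), i.e. softplus magnitudes (the "what") times a cosine of the phase difference (the "where"). A numerical check with toy, illustrative names:

```python
# Verifies sum-over-dims identity: q_polar @ k_polar == sum_d a_d * b_d * cos(di - dj - phi).
import torch
import torch.nn.functional as F

torch.manual_seed(0)
dim = 8
theta_i = torch.rand(dim)           # query phase angles (position i times inv_freq)
theta_j = torch.rand(dim)           # key phase angles (position j times inv_freq)
phi = torch.rand(dim)               # learned per-dimension key phase bias

a = F.softplus(torch.randn(dim))    # query magnitudes ("what")
b = F.softplus(torch.randn(dim))    # key magnitudes ("what")

q = torch.cat((a * theta_i.cos(), a * theta_i.sin()))
k = torch.cat((b * (theta_j + phi).cos(), b * (theta_j + phi).sin()))

lhs = q @ k
rhs = (a * b * (theta_i - theta_j - phi).cos()).sum()
print(torch.allclose(lhs, rhs, atol = 1e-6))   # True
```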
@@ -2123,6 +2197,7 @@ class AttentionLayers(Module):
         use_scalenorm = False,
         use_rmsnorm = False,
         use_dynamic_tanh = False,
+        use_derf = False,
         dynamic_tanh_init_alpha = 1.,
         use_simple_rmsnorm = False,
         use_adaptive_layernorm = False,

@@ -2148,6 +2223,8 @@ class AttentionLayers(Module):
         rotary_xpos_scale_base = 512,
         rotary_base_rescale_factor = 1.,
         rotate_num_heads = None,
+        polar_pos_emb = False,
+        polar_bias_uniform_init = False,
         weight_tie_layers = False,
         custom_layers: tuple[str, ...] | None = None,
         layers_execute_order: tuple[int, ...] | None = None,

@@ -2240,9 +2317,14 @@ class AttentionLayers(Module):
         if verbose and rotary_emb_dim < 32:
             logger.warning('when training language model, rotary embedding dimension should be at least 32')
 
+        assert at_most_one_of(rotary_pos_emb, polar_pos_emb), f'either rotary positional embedding or polar positional embedding can be turned on'
         assert not (rotary_xpos and not causal), 'rotary xpos is not compatible with bidirectional attention'
         self.rotary_pos_emb = RotaryEmbedding(rotary_emb_dim, use_xpos = rotary_xpos, scale_base = rotary_xpos_scale_base, interpolation_factor = rotary_interpolation_factor, base_rescale_factor = rotary_base_rescale_factor) if rotary_pos_emb else None
 
+        # polar positional embedding (PoPE) - https://arxiv.org/abs/2509.10534
+
+        self.polar_pos_emb = PolarEmbedding(dim_head, polar_bias_uniform_init) if polar_pos_emb else None
+
         assert at_most_one_of(alibi_pos_bias, rel_pos_bias, data_dependent_alibi), 'you can only choose one of Alibi positional bias, data dependent Alibi (forgetting transformers), dynamic tanh, or T5 relative positional bias'
         assert rel_pos_num_buckets <= rel_pos_max_distance, 'number of relative position buckets must be less than the relative position max distance'
 

@@ -2277,7 +2359,7 @@ class AttentionLayers(Module):
 
         # determine norm
 
-        assert at_most_one_of(use_scalenorm, use_rmsnorm, use_dynamic_tanh, use_simple_rmsnorm, use_adaptive_layernorm, use_adaptive_rmsnorm), 'you can only use either scalenorm, rmsnorm, adaptive layernorm, adaptive rmsnorm, or simple rmsnorm'
+        assert at_most_one_of(use_scalenorm, use_rmsnorm, use_dynamic_tanh, use_derf, use_simple_rmsnorm, use_adaptive_layernorm, use_adaptive_rmsnorm), 'you can only use either scalenorm, rmsnorm, adaptive layernorm, adaptive rmsnorm, or simple rmsnorm'
 
         norm_need_condition = False
         dim_condition = default(dim_condition, dim)

@@ -2295,6 +2377,8 @@ class AttentionLayers(Module):
         elif use_dynamic_tanh:
             assert pre_norm, 'dynamic tanh norm only tested for pre-norm'
             norm_class = partial(DynamicTanh, init_alpha = dynamic_tanh_init_alpha)
+        elif use_derf:
+            norm_class = Derf
         elif use_adaptive_layernorm:
             norm_need_condition = True
             norm_class = partial(AdaptiveLayerNorm, dim_condition = dim_condition * dim_condition_mult)

@@ -2598,6 +2682,7 @@ class AttentionLayers(Module):
         cache_age = 1,
         return_hiddens = False,
         rotary_pos_emb = None,
+        polar_pos_emb = None,
         pos = None,
         context_pos = None,
         attn_bias = None,

@@ -2693,6 +2778,15 @@ class AttentionLayers(Module):
                 context_rotary_pos_emb = context_rotary_pos_emb
             )
 
+        # polar positions
+
+        if exists(self.polar_pos_emb):
+            if not exists(polar_pos_emb):
+                if not exists(pos):
+                    pos = arange(x.shape[1] + seq_pos_offset, device = x.device)
+
+                polar_pos_emb = self.polar_pos_emb(pos)
+
         # assume cached key / values
 
         prev_cache_length = 0

@@ -2882,7 +2976,7 @@ class AttentionLayers(Module):
             # forward depending on layer type
 
             if layer_type == 'a':
-                out, inter = block(x, mask = mask, context_mask = self_attn_kv_mask, attn_mask = attn_mask, rel_pos = self.rel_pos, pos = pos, rotary_pos_emb = rotary_pos_emb, additional_key_values = next(iter_self_attn_kv, None), additional_key_value_mask = additional_kv_mask, prev_attn = prev_attn, cache = next(iter_attn_cache, None), mem = layer_mem, mem_mask = layer_mem_mask, attn_bias = attn_bias, kv_input_residual = next(self_attn_kv_residuals_iter, None), value_residual = maybe_self_attn_value_residual, return_intermediates = True)
+                out, inter = block(x, mask = mask, context_mask = self_attn_kv_mask, attn_mask = attn_mask, rel_pos = self.rel_pos, pos = pos, rotary_pos_emb = rotary_pos_emb, polar_pos_emb = polar_pos_emb, additional_key_values = next(iter_self_attn_kv, None), additional_key_value_mask = additional_kv_mask, prev_attn = prev_attn, cache = next(iter_attn_cache, None), mem = layer_mem, mem_mask = layer_mem_mask, attn_bias = attn_bias, kv_input_residual = next(self_attn_kv_residuals_iter, None), value_residual = maybe_self_attn_value_residual, return_intermediates = True)
             elif layer_type == 'c':
                 out, inter = block(x, context = context, mask = mask, context_mask = context_mask, prev_attn = prev_cross_attn, cache = next(iter_attn_cache, None), kv_input_residual = next(cross_attn_kv_residuals_iter, None), value_residual = maybe_cross_attn_value_residual, **cross_attn_rotary_pos_emb, return_intermediates = True)
             elif layer_type == 'f':
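Taken together, the release exposes two new flags on the attention layers. A usage sketch mirroring `test_derf` and `test_pope` above (hyperparameters copied from those tests; enabling both flags at once is not exercised by the tests):

```python
import torch
from x_transformers import TransformerWrapper, Decoder

# polar coordinate positional embeddings (PoPE, arXiv:2509.10534)
pope_model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        polar_pos_emb = True
    )
)

# Derf in place of the usual pre-norm (arXiv:2512.10938)
derf_model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        use_derf = True
    )
)

x = torch.randint(0, 256, (1, 10))
logits = pope_model(x)   # (1, 10, 256)
```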