x-transformers 2.7.4__tar.gz → 2.7.6__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {x_transformers-2.7.4 → x_transformers-2.7.6}/PKG-INFO +1 -1
- {x_transformers-2.7.4 → x_transformers-2.7.6}/pyproject.toml +1 -1
- {x_transformers-2.7.4 → x_transformers-2.7.6}/tests/test_x_transformers.py +16 -1
- {x_transformers-2.7.4 → x_transformers-2.7.6}/x_transformers/x_transformers.py +24 -0
- {x_transformers-2.7.4 → x_transformers-2.7.6}/.github/FUNDING.yml +0 -0
- {x_transformers-2.7.4 → x_transformers-2.7.6}/.github/workflows/python-publish.yml +0 -0
- {x_transformers-2.7.4 → x_transformers-2.7.6}/.github/workflows/python-test.yaml +0 -0
- {x_transformers-2.7.4 → x_transformers-2.7.6}/.gitignore +0 -0
- {x_transformers-2.7.4 → x_transformers-2.7.6}/LICENSE +0 -0
- {x_transformers-2.7.4 → x_transformers-2.7.6}/README.md +0 -0
- {x_transformers-2.7.4 → x_transformers-2.7.6}/data/README.md +0 -0
- {x_transformers-2.7.4 → x_transformers-2.7.6}/data/enwik8.gz +0 -0
- {x_transformers-2.7.4 → x_transformers-2.7.6}/images/all-attention.png +0 -0
- {x_transformers-2.7.4 → x_transformers-2.7.6}/images/attention-on-attention.png +0 -0
- {x_transformers-2.7.4 → x_transformers-2.7.6}/images/cosine-sim-attention.png +0 -0
- {x_transformers-2.7.4 → x_transformers-2.7.6}/images/deepnorm.png +0 -0
- {x_transformers-2.7.4 → x_transformers-2.7.6}/images/dynamic-pos-bias-linear.png +0 -0
- {x_transformers-2.7.4 → x_transformers-2.7.6}/images/dynamic-pos-bias-log.png +0 -0
- {x_transformers-2.7.4 → x_transformers-2.7.6}/images/dynamic-pos-bias-sinusoidal.png +0 -0
- {x_transformers-2.7.4 → x_transformers-2.7.6}/images/dynamic-pos-bias.png +0 -0
- {x_transformers-2.7.4 → x_transformers-2.7.6}/images/enhanced-recurrence.png +0 -0
- {x_transformers-2.7.4 → x_transformers-2.7.6}/images/fcm.png +0 -0
- {x_transformers-2.7.4 → x_transformers-2.7.6}/images/ffglu.png +0 -0
- {x_transformers-2.7.4 → x_transformers-2.7.6}/images/flash-attention.png +0 -0
- {x_transformers-2.7.4 → x_transformers-2.7.6}/images/gate_values.png +0 -0
- {x_transformers-2.7.4 → x_transformers-2.7.6}/images/gating.png +0 -0
- {x_transformers-2.7.4 → x_transformers-2.7.6}/images/length-extrapolation-scale.png +0 -0
- {x_transformers-2.7.4 → x_transformers-2.7.6}/images/macaron-1.png +0 -0
- {x_transformers-2.7.4 → x_transformers-2.7.6}/images/macaron-2.png +0 -0
- {x_transformers-2.7.4 → x_transformers-2.7.6}/images/memory-transformer.png +0 -0
- {x_transformers-2.7.4 → x_transformers-2.7.6}/images/normformer.png +0 -0
- {x_transformers-2.7.4 → x_transformers-2.7.6}/images/pia.png +0 -0
- {x_transformers-2.7.4 → x_transformers-2.7.6}/images/qknorm-analysis.png +0 -0
- {x_transformers-2.7.4 → x_transformers-2.7.6}/images/resi_dual.png +0 -0
- {x_transformers-2.7.4 → x_transformers-2.7.6}/images/residual_attn.png +0 -0
- {x_transformers-2.7.4 → x_transformers-2.7.6}/images/rezero.png +0 -0
- {x_transformers-2.7.4 → x_transformers-2.7.6}/images/rotary.png +0 -0
- {x_transformers-2.7.4 → x_transformers-2.7.6}/images/sandwich-2.png +0 -0
- {x_transformers-2.7.4 → x_transformers-2.7.6}/images/sandwich.png +0 -0
- {x_transformers-2.7.4 → x_transformers-2.7.6}/images/sandwich_norm.png +0 -0
- {x_transformers-2.7.4 → x_transformers-2.7.6}/images/scalenorm.png +0 -0
- {x_transformers-2.7.4 → x_transformers-2.7.6}/images/talking-heads.png +0 -0
- {x_transformers-2.7.4 → x_transformers-2.7.6}/images/topk-attention.png +0 -0
- {x_transformers-2.7.4 → x_transformers-2.7.6}/images/xval.png +0 -0
- {x_transformers-2.7.4 → x_transformers-2.7.6}/train_belief_state.py +0 -0
- {x_transformers-2.7.4 → x_transformers-2.7.6}/train_copy.py +0 -0
- {x_transformers-2.7.4 → x_transformers-2.7.6}/train_entropy_tokenizer.py +0 -0
- {x_transformers-2.7.4 → x_transformers-2.7.6}/train_enwik8.py +0 -0
- {x_transformers-2.7.4 → x_transformers-2.7.6}/train_length_extrapolate.py +0 -0
- {x_transformers-2.7.4 → x_transformers-2.7.6}/train_parity.py +0 -0
- {x_transformers-2.7.4 → x_transformers-2.7.6}/x_transformers/__init__.py +0 -0
- {x_transformers-2.7.4 → x_transformers-2.7.6}/x_transformers/attend.py +0 -0
- {x_transformers-2.7.4 → x_transformers-2.7.6}/x_transformers/autoregressive_wrapper.py +0 -0
- {x_transformers-2.7.4 → x_transformers-2.7.6}/x_transformers/belief_state_wrapper.py +0 -0
- {x_transformers-2.7.4 → x_transformers-2.7.6}/x_transformers/continuous.py +0 -0
- {x_transformers-2.7.4 → x_transformers-2.7.6}/x_transformers/dpo.py +0 -0
- {x_transformers-2.7.4 → x_transformers-2.7.6}/x_transformers/entropy_based_tokenizer.py +0 -0
- {x_transformers-2.7.4 → x_transformers-2.7.6}/x_transformers/multi_input.py +0 -0
- {x_transformers-2.7.4 → x_transformers-2.7.6}/x_transformers/neo_mlp.py +0 -0
- {x_transformers-2.7.4 → x_transformers-2.7.6}/x_transformers/nonautoregressive_wrapper.py +0 -0
- {x_transformers-2.7.4 → x_transformers-2.7.6}/x_transformers/up_wrapper.py +0 -0
- {x_transformers-2.7.4 → x_transformers-2.7.6}/x_transformers/xl_autoregressive_wrapper.py +0 -0
- {x_transformers-2.7.4 → x_transformers-2.7.6}/x_transformers/xval.py +0 -0
@@ -1315,7 +1315,7 @@ def test_simple_mdlm(
|
|
1315
1315
|
loss = nar(seq)
|
1316
1316
|
loss.loss.backward()
|
1317
1317
|
|
1318
|
-
def
|
1318
|
+
def test_qk_clip_attn():
|
1319
1319
|
from x_transformers import Attention
|
1320
1320
|
|
1321
1321
|
x = torch.randn(1, 1024, 512)
|
@@ -1325,3 +1325,18 @@ def test_qk_clip():
|
|
1325
1325
|
out, intermediates = attn(x, return_intermediates = True)
|
1326
1326
|
|
1327
1327
|
attn.qk_clip_(intermediates, tau = 100)
|
1328
|
+
|
1329
|
+
def test_qk_clip_attn_layers():
|
1330
|
+
from x_transformers import TransformerWrapper, Decoder
|
1331
|
+
|
1332
|
+
model = TransformerWrapper(
|
1333
|
+
num_tokens = 256,
|
1334
|
+
max_seq_len = 1024,
|
1335
|
+
attn_layers = Decoder(dim = 512, depth = 2)
|
1336
|
+
)
|
1337
|
+
|
1338
|
+
seq = torch.randint(0, 256, (1, 1024))
|
1339
|
+
|
1340
|
+
out, intermediates = model(seq, return_intermediates = True)
|
1341
|
+
|
1342
|
+
model.attn_qk_clip_(intermediates)
|
@@ -2462,6 +2462,23 @@ class AttentionLayers(Module):
|
|
2462
2462
|
|
2463
2463
|
self.can_cache_kv = all([module.can_cache_kv for module in self.modules() if isinstance(module, Attention)])
|
2464
2464
|
|
2465
|
+
def attn_qk_clip_(
|
2466
|
+
self,
|
2467
|
+
intermediates: LayerIntermediates,
|
2468
|
+
tau = 100.
|
2469
|
+
):
|
2470
|
+
# pairs up the attention intermediates with each attention module and does qk clip proposed by kimi team
|
2471
|
+
|
2472
|
+
layer_and_layer_types = (self.layers, self.layer_types)
|
2473
|
+
|
2474
|
+
attn_layers = [layer for (_, layer, _), layer_type in zip(self.layers, self.layer_types) if layer_type in ('a', 'c')]
|
2475
|
+
attn_intermeds = intermediates.attn_intermediates
|
2476
|
+
|
2477
|
+
assert len(attn_layers) == len(attn_intermeds)
|
2478
|
+
|
2479
|
+
for attn_layer, attn_inter in zip(attn_layers, attn_intermeds):
|
2480
|
+
attn_layer.qk_clip_(attn_inter, tau = tau)
|
2481
|
+
|
2465
2482
|
def forward(
|
2466
2483
|
self,
|
2467
2484
|
x,
|
@@ -3192,6 +3209,13 @@ class TransformerWrapper(Module):
|
|
3192
3209
|
if not isinstance(self.pos_emb, always):
|
3193
3210
|
nn.init.normal_(self.pos_emb.emb.weight, std = 1e-5)
|
3194
3211
|
|
3212
|
+
def attn_qk_clip_(
|
3213
|
+
self,
|
3214
|
+
intermediates: LayerIntermediates,
|
3215
|
+
tau = 100.
|
3216
|
+
):
|
3217
|
+
self.attn_layers.attn_qk_clip_(intermediates, tau = tau)
|
3218
|
+
|
3195
3219
|
def forward(
|
3196
3220
|
self,
|
3197
3221
|
x,
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|