x-transformers 1.23.0__tar.gz → 1.23.2__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {x-transformers-1.23.0/x_transformers.egg-info → x-transformers-1.23.2}/PKG-INFO +1 -1
- {x-transformers-1.23.0 → x-transformers-1.23.2}/README.md +11 -10
- {x-transformers-1.23.0 → x-transformers-1.23.2}/setup.py +1 -1
- {x-transformers-1.23.0 → x-transformers-1.23.2}/x_transformers/attend.py +0 -76
- {x-transformers-1.23.0 → x-transformers-1.23.2}/x_transformers/x_transformers.py +16 -8
- {x-transformers-1.23.0 → x-transformers-1.23.2/x_transformers.egg-info}/PKG-INFO +1 -1
- {x-transformers-1.23.0 → x-transformers-1.23.2}/LICENSE +0 -0
- {x-transformers-1.23.0 → x-transformers-1.23.2}/setup.cfg +0 -0
- {x-transformers-1.23.0 → x-transformers-1.23.2}/x_transformers/__init__.py +0 -0
- {x-transformers-1.23.0 → x-transformers-1.23.2}/x_transformers/autoregressive_wrapper.py +0 -0
- {x-transformers-1.23.0 → x-transformers-1.23.2}/x_transformers/continuous_autoregressive_wrapper.py +0 -0
- {x-transformers-1.23.0 → x-transformers-1.23.2}/x_transformers/nonautoregressive_wrapper.py +0 -0
- {x-transformers-1.23.0 → x-transformers-1.23.2}/x_transformers/xl_autoregressive_wrapper.py +0 -0
- {x-transformers-1.23.0 → x-transformers-1.23.2}/x_transformers.egg-info/SOURCES.txt +0 -0
- {x-transformers-1.23.0 → x-transformers-1.23.2}/x_transformers.egg-info/dependency_links.txt +0 -0
- {x-transformers-1.23.0 → x-transformers-1.23.2}/x_transformers.egg-info/requires.txt +0 -0
- {x-transformers-1.23.0 → x-transformers-1.23.2}/x_transformers.egg-info/top_level.txt +0 -0
README.md
@@ -1932,16 +1932,6 @@ generated = model.generate(start_emb, 17) # (17, 777)
 }
 ```
 
-```bibtex
-@article{Liu2023EfficientViTME,
-    title   = {EfficientViT: Memory Efficient Vision Transformer with Cascaded Group Attention},
-    author  = {Xinyu Liu and Houwen Peng and Ningxin Zheng and Yuqing Yang and Han Hu and Yixuan Yuan},
-    journal = {ArXiv},
-    year    = {2023},
-    volume  = {abs/2305.07027}
-}
-```
-
 ```bibtex
 @article{Kazemnejad2023TheIO,
     title   = {The Impact of Positional Encoding on Length Generalization in Transformers},
@@ -2007,4 +1997,15 @@ generated = model.generate(start_emb, 17) # (17, 777)
 }
 ```
 
+```bibtex
+@article{Bondarenko2023QuantizableTR,
+    title   = {Quantizable Transformers: Removing Outliers by Helping Attention Heads Do Nothing},
+    author  = {Yelysei Bondarenko and Markus Nagel and Tijmen Blankevoort},
+    journal = {ArXiv},
+    year    = {2023},
+    volume  = {abs/2306.12929},
+    url     = {https://api.semanticscholar.org/CorpusID:259224568}
+}
+```
+
 *solve intelligence... then use that to solve everything else.* - Demis Hassabis
x_transformers/attend.py
@@ -346,79 +346,3 @@ class Attend(nn.Module):
         )
 
         return out, intermediates
-
-# cascading heads logic
-
-def to_single_heads(t, dim = 1):
-    heads = t.unbind(dim = dim)
-    return tuple(head.unsqueeze(dim) for head in heads)
-
-class CascadingHeads(nn.Module):
-    def __init__(self, attend: Attend):
-        super().__init__()
-        self.attend = attend
-
-    def forward(
-        self,
-        q, k, v,
-        mask = None,
-        attn_bias = None,
-        prev_attn = None
-    ):
-        assert q.shape[-1] == v.shape[-1], 'cascading heads can only be done if query / key and value head dimensions are the same'
-
-        # split inputs into per-head inputs
-
-        heads = q.shape[1]
-
-        queries = to_single_heads(q)
-        keys = to_single_heads(k) if k.ndim == 4 else ((k,) * heads)
-        values = to_single_heads(v) if v.ndim == 4 else ((v,) * heads)
-
-        mask = (mask,) * heads
-
-        attn_bias = to_single_heads(attn_bias, dim = 0) if exists(attn_bias) else ((None,) * heads)
-        prev_attn = to_single_heads(prev_attn) if exists(prev_attn) else ((None,) * heads)
-
-        # now loop through each head, without output of previous head summed with the next head
-        # thus cascading
-
-        all_outs = []
-        all_intermediates = []
-
-        prev_head_out = None
-
-        for h_q, h_k, h_v, h_mask, h_attn_bias, h_prev_attn in zip(queries, keys, values, mask, attn_bias, prev_attn):
-
-            if exists(prev_head_out):
-                h_q = h_q + prev_head_out
-
-            out, intermediates = self.attend(
-                h_q, h_k, h_v,
-                mask = h_mask,
-                attn_bias = h_attn_bias,
-                prev_attn = h_prev_attn
-            )
-
-            prev_head_out = out
-
-            all_outs.append(out)
-            all_intermediates.append(intermediates)
-
-        # cat all output heads
-
-        all_outs = torch.cat(all_outs, dim = 1)
-
-        # cat all intermediates, if they exist
-
-        qk_similarities, pre_softmax_attn, post_softmax_attn = zip(*map(lambda i: i.to_tuple(), all_intermediates))
-
-        qk_similarities, pre_softmax_attn, post_softmax_attn = map(compact, (qk_similarities, pre_softmax_attn, post_softmax_attn))
-
-        aggregated_intermediates = Intermediates(
-            qk_similarities = torch.cat(qk_similarities, dim = 1) if len(qk_similarities) > 0 else None,
-            pre_softmax_attn = torch.cat(pre_softmax_attn, dim = 1) if len(pre_softmax_attn) > 0 else None,
-            post_softmax_attn = torch.cat(post_softmax_attn, dim = 1) if len(post_softmax_attn) > 0 else None
-        )
-
-        return all_outs, aggregated_intermediates
x_transformers/x_transformers.py
@@ -14,7 +14,7 @@ from typing import List, Callable, Optional
 from einops import rearrange, repeat, reduce, pack, unpack
 from einops.layers.torch import Rearrange
 
-from x_transformers.attend import Attend, Intermediates
+from x_transformers.attend import Attend, Intermediates
 from x_transformers.autoregressive_wrapper import AutoregressiveWrapper
 
 # constants
@@ -650,6 +650,7 @@ class Attention(nn.Module):
         num_mem_kv = 0,
         dropout = 0.,
         on_attn = False,
+        gate_value_heads = False,
         gate_values = False,
         zero_init_output = False,
         max_attend_past = None,
@@ -662,7 +663,6 @@ class Attention(nn.Module):
         shared_kv = False,
         value_dim_head = None,
         tensor_product = False, # https://arxiv.org/abs/2208.06061
-        cascading_heads = False,
         add_zero_kv = False, # same as add_zero_attn in pytorch
         rotary_embed_values = False,
         onnxable = False
@@ -674,7 +674,6 @@ class Attention(nn.Module):
         self.causal = causal
         self.max_attend_past = max_attend_past
 
-
         assert not (exists(kv_heads) and one_kv_head), 'either attn_one_kv_head is set to True (in which case kv_heads is set to 1), or attn_kv_heads is set, but not both'
 
         value_dim_head = default(value_dim_head, dim_head)
@@ -705,7 +704,14 @@ class Attention(nn.Module):
         if gate_values:
             self.to_v_gate = nn.Linear(dim, out_dim)
             nn.init.constant_(self.to_v_gate.weight, 0)
-            nn.init.constant_(self.to_v_gate.bias,
+            nn.init.constant_(self.to_v_gate.bias, 10)
+
+        # add per head gating of the output values, from 'Attend to nothing' paper
+        self.to_v_head_gate = None
+        if gate_value_heads:
+            self.to_v_head_gate = nn.Linear(dim, heads)
+            nn.init.constant_(self.to_v_head_gate.weight, 0)
+            nn.init.constant_(self.to_v_head_gate.bias, 10)
 
         # cosine sim attention
         self.qk_norm = qk_norm
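A minimal usage sketch for the new `gate_value_heads` flag added above. It assumes the flag reaches `Attention` through the library's usual `attn_`-prefixed keyword routing on the attention layers; that routing is an assumption here, not something shown in this diff.

```python
import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        attn_gate_value_heads = True   # assumed kwarg routing: enables per-head gating of attention outputs
    )
)

x = torch.randint(0, 20000, (1, 1024))
logits = model(x)   # (1, 1024, 20000)
```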
@@ -738,10 +744,6 @@ class Attention(nn.Module):
             onnxable = onnxable
         )
 
-        if cascading_heads:
-            # cascading heads - wrap the Attend logic
-            self.attend = CascadingHeads(self.attend)
-
         # head scaling
         self.head_scale = head_scale
         if head_scale:
@@ -911,6 +913,12 @@ class Attention(nn.Module):
         if head_scale:
             out = out * self.head_scale_params
 
+        # per head gating, from https://arxiv.org/abs/2306.12929
+
+        if exists(self.to_v_head_gate):
+            head_gate = self.to_v_head_gate(x)
+            out = out * rearrange(head_gate, 'b n h -> b h n 1').sigmoid()
+
         # merge heads
 
         out = rearrange(out, 'b h n d -> b n (h d)')
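To make the shapes in the added `rearrange` concrete, here is a standalone sketch of the per-head gate with made-up dimensions; all names and sizes below are illustrative and not part of the package.

```python
import torch
from torch import nn
from einops import rearrange

b, h, n, d, dim = 2, 8, 16, 64, 512       # batch, heads, sequence length, head dim, model dim

x   = torch.randn(b, n, dim)              # token representations the gate is computed from
out = torch.randn(b, h, n, d)             # per-head attention outputs before merging

to_v_head_gate = nn.Linear(dim, h)
nn.init.constant_(to_v_head_gate.weight, 0)
nn.init.constant_(to_v_head_gate.bias, 10)    # sigmoid(10) ≈ 1, so every head starts fully open

head_gate = to_v_head_gate(x)                                         # (b, n, h)
gated = out * rearrange(head_gate, 'b n h -> b h n 1').sigmoid()      # one scalar gate per head per position
```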