x-transformers 2.5.0__py3-none-any.whl → 2.5.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- x_transformers/x_transformers.py +15 -4
- {x_transformers-2.5.0.dist-info → x_transformers-2.5.2.dist-info}/METADATA +1 -1
- {x_transformers-2.5.0.dist-info → x_transformers-2.5.2.dist-info}/RECORD +5 -5
- {x_transformers-2.5.0.dist-info → x_transformers-2.5.2.dist-info}/WHEEL +0 -0
- {x_transformers-2.5.0.dist-info → x_transformers-2.5.2.dist-info}/licenses/LICENSE +0 -0
x_transformers/x_transformers.py
CHANGED
@@ -2038,6 +2038,9 @@ class AttentionLayers(Module):
         self.causal = causal
         self.layers = ModuleList([])

+        self.attn_heads = heads
+        self.attn_dim_head = dim_head
+
         # routing related
         # 1. greater than one residual stream, proposed in Hyper-Connections paper https://arxiv.org/abs/2409.19606
         # 2. integrating more than one past layer, from LIMe paper https://arxiv.org/abs/2502.09245
@@ -2758,6 +2761,9 @@ class AttentionPool(Module):
         dim_context = None,
         add_residual = False,
         depth = 1,
+        heads = 8,
+        dim_head = 64,
+        use_transformer_blocks = None,
         squeeze_output = None,
         attn_kwargs: dict = dict()
     ):
@@ -2767,13 +2773,18 @@ class AttentionPool(Module):
         squeeze_output = default(squeeze_output, False)
         assert not (squeeze_output and num_pooled_tokens > 1)

+        use_transformer_blocks = default(use_transformer_blocks, depth > 1)
+        assert use_transformer_blocks or depth == 1
+
         self.queries = nn.Parameter(torch.randn(num_pooled_tokens, dim) * 1e-2)

-        if
+        if use_transformer_blocks:
             assert not add_residual, 'residual already in effect when doing a full cross attention based transformer for pooling'
-
+            attn_kwargs = {f'attn_{k}': v for k, v in attn_kwargs.items()}
+
+            self.pooler = CrossAttender(dim = dim, cross_attn_dim_context = dim_context, depth = depth, heads = heads, attn_dim_head = dim_head, )
         else:
-            self.pooler = Attention(dim = dim, dim_context = dim_context, **attn_kwargs)
+            self.pooler = Attention(dim = dim, dim_context = dim_context, heads = heads, dim_head = dim_head, **attn_kwargs)

         self.add_residual = add_residual

@@ -2999,7 +3010,7 @@ class TransformerWrapper(Module):
         self.attn_pool = None

         if attn_pool:
-            self.attn_pool = AttentionPool(dim = default(dim_pooled_tokens, dim), dim_context = dim, num_pooled_tokens = num_pooled_tokens, depth = attn_pool_depth)
+            self.attn_pool = AttentionPool(dim = default(dim_pooled_tokens, dim), dim_context = dim, num_pooled_tokens = num_pooled_tokens, depth = attn_pool_depth, heads = self.attn_layers.attn_heads, dim_head = self.attn_layers.attn_dim_head)

         # whether to average pool the embed (`global average pool`)

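Taken together, the changes above store the attention geometry on AttentionLayers (attn_heads / attn_dim_head), forward it from TransformerWrapper into AttentionPool, and add a use_transformer_blocks flag to AttentionPool that defaults to depth > 1. The snippet below is a standalone, illustrative sketch that mirrors only the constructor logic visible in this diff; resolve_pooler_config and the local default() helper are hypothetical stand-ins, not x-transformers API.

# Illustrative sketch only -- mirrors the constructor logic shown in the diff above,
# not actual x-transformers code. `resolve_pooler_config` is a hypothetical helper.

def default(val, d):
    # same convention as the library's default(): fall back to d when val is None
    return val if val is not None else d

def resolve_pooler_config(depth = 1, heads = 8, dim_head = 64, use_transformer_blocks = None, attn_kwargs: dict = dict()):
    # new in 2.5.2: the pooler uses full cross-attention transformer blocks
    # whenever depth > 1, unless the caller overrides use_transformer_blocks
    use_transformer_blocks = default(use_transformer_blocks, depth > 1)
    assert use_transformer_blocks or depth == 1

    if use_transformer_blocks:
        # attention kwargs destined for CrossAttender are re-keyed with an attn_ prefix
        attn_kwargs = {f'attn_{k}': v for k, v in attn_kwargs.items()}
        return dict(pooler = 'CrossAttender', depth = depth, heads = heads, attn_dim_head = dim_head, **attn_kwargs)

    # single-layer case keeps the plain cross-attention pooler, now with explicit heads / dim_head
    return dict(pooler = 'Attention', heads = heads, dim_head = dim_head, **attn_kwargs)

print(resolve_pooler_config(depth = 1))
print(resolve_pooler_config(depth = 2, attn_kwargs = dict(dropout = 0.1)))

Before this change the wrapper's attention pool fell back to the pooler's own default head configuration; passing self.attn_layers.attn_heads and attn_dim_head keeps the pooling attention consistent with the main attention stack.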
{x_transformers-2.5.0.dist-info → x_transformers-2.5.2.dist-info}/RECORD
CHANGED
@@ -9,10 +9,10 @@ x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg
 x_transformers/neo_mlp.py,sha256=XCNnnop9WLarcxap1kGuYc1x8GHvwkZiDRnXOxSl3Po,3452
 x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
 x_transformers/up_wrapper.py,sha256=YC2LN14_7Xx9Wtiek2rtEJ_qHqdfSmKlh3d7Cgxwd80,7073
-x_transformers/x_transformers.py,sha256=
+x_transformers/x_transformers.py,sha256=vmMrHP3hAQ9iAJlRN1pKmXOn7pD3mfh_ndtaR7LMPzU,119860
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=AwwYUm8yDAtKQyKJDIhYMsiLTJ_skh3scUFMjp5sda8,8597
-x_transformers-2.5.
-x_transformers-2.5.
-x_transformers-2.5.
-x_transformers-2.5.
+x_transformers-2.5.2.dist-info/METADATA,sha256=yeferX_PJIv0Lxs36vZSV7Z2w9ol4udiUAON95hP_bY,90223
+x_transformers-2.5.2.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+x_transformers-2.5.2.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-2.5.2.dist-info/RECORD,,
{x_transformers-2.5.0.dist-info → x_transformers-2.5.2.dist-info}/WHEEL
File without changes
{x_transformers-2.5.0.dist-info → x_transformers-2.5.2.dist-info}/licenses/LICENSE
File without changes