x-transformers 2.5.1__py3-none-any.whl → 2.5.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- x_transformers/x_transformers.py +9 -1
- {x_transformers-2.5.1.dist-info → x_transformers-2.5.3.dist-info}/METADATA +1 -1
- {x_transformers-2.5.1.dist-info → x_transformers-2.5.3.dist-info}/RECORD +5 -5
- {x_transformers-2.5.1.dist-info → x_transformers-2.5.3.dist-info}/WHEEL +0 -0
- {x_transformers-2.5.1.dist-info → x_transformers-2.5.3.dist-info}/licenses/LICENSE +0 -0
x_transformers/x_transformers.py
CHANGED
@@ -2763,6 +2763,7 @@ class AttentionPool(Module):
         depth = 1,
         heads = 8,
         dim_head = 64,
+        use_transformer_blocks = None,
         squeeze_output = None,
         attn_kwargs: dict = dict()
     ):
@@ -2772,9 +2773,12 @@ class AttentionPool(Module):
         squeeze_output = default(squeeze_output, False)
         assert not (squeeze_output and num_pooled_tokens > 1)
 
+        use_transformer_blocks = default(use_transformer_blocks, depth > 1)
+        assert use_transformer_blocks or depth == 1
+
         self.queries = nn.Parameter(torch.randn(num_pooled_tokens, dim) * 1e-2)
 
-        if depth > 1:
+        if use_transformer_blocks:
             assert not add_residual, 'residual already in effect when doing a full cross attention based transformer for pooling'
             attn_kwargs = {f'attn_{k}': v for k, v in attn_kwargs.items()}
 
@@ -2783,6 +2787,7 @@ class AttentionPool(Module):
             self.pooler = Attention(dim = dim, dim_context = dim_context, heads = heads, dim_head = dim_head, **attn_kwargs)
 
         self.add_residual = add_residual
+        self.squeeze_output = squeeze_output
 
     def forward(self, context, mask = None):
         batch = context.shape[0]
@@ -2794,6 +2799,9 @@ class AttentionPool(Module):
         if self.add_residual:
             pooled = pooled + queries
 
+        if self.squeeze_output:
+            pooled = rearrange(pooled, 'b 1 d -> b d')
+
         return pooled
 
 class ViTransformerWrapper(Module):
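For orientation, a minimal usage sketch of the two options this release adds to AttentionPool: use_transformer_blocks (defaults to depth > 1 and must be enabled whenever depth exceeds 1) and squeeze_output (drops the singleton token dimension when a single token is pooled). The import path, the default of num_pooled_tokens = 1, and the mask convention are assumptions not confirmed by this diff.

    import torch
    from x_transformers.x_transformers import AttentionPool

    # single cross-attention pooling layer (depth = 1), so transformer blocks
    # are not required; squeeze_output = True makes forward() return (batch, dim)
    # instead of (batch, 1, dim)
    pool = AttentionPool(
        dim = 512,
        depth = 1,
        heads = 8,
        dim_head = 64,
        use_transformer_blocks = False,  # new in this release; defaults to depth > 1
        squeeze_output = True            # only valid when a single token is pooled
    )

    tokens = torch.randn(2, 196, 512)              # (batch, seq, dim) context to pool over
    mask = torch.ones(2, 196, dtype = torch.bool)  # assumed: True marks valid positions
    pooled = pool(tokens, mask = mask)             # -> torch.Size([2, 512])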
{x_transformers-2.5.1.dist-info → x_transformers-2.5.3.dist-info}/RECORD
CHANGED
@@ -9,10 +9,10 @@ x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg
 x_transformers/neo_mlp.py,sha256=XCNnnop9WLarcxap1kGuYc1x8GHvwkZiDRnXOxSl3Po,3452
 x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
 x_transformers/up_wrapper.py,sha256=YC2LN14_7Xx9Wtiek2rtEJ_qHqdfSmKlh3d7Cgxwd80,7073
-x_transformers/x_transformers.py,sha256=
+x_transformers/x_transformers.py,sha256=fW-AoomNCw4n2JFbZN9rZV3lKQvz_Tl6L4txUvac_9o,119993
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=AwwYUm8yDAtKQyKJDIhYMsiLTJ_skh3scUFMjp5sda8,8597
-x_transformers-2.5.
-x_transformers-2.5.
-x_transformers-2.5.
-x_transformers-2.5.
+x_transformers-2.5.3.dist-info/METADATA,sha256=iR77ECuqz3O70zaZ5Mx3NwbNNal-FerMlKPlXTbv8vE,90223
+x_transformers-2.5.3.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+x_transformers-2.5.3.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-2.5.3.dist-info/RECORD,,
{x_transformers-2.5.1.dist-info → x_transformers-2.5.3.dist-info}/WHEEL
File without changes

{x_transformers-2.5.1.dist-info → x_transformers-2.5.3.dist-info}/licenses/LICENSE
File without changes