x-transformers 2.5.0-py3-none-any.whl → 2.5.2-py3-none-any.whl

This diff compares the contents of two publicly released versions of this package, as they appear in their respective public registries, and is provided for informational purposes only.
x_transformers/x_transformers.py

@@ -2038,6 +2038,9 @@ class AttentionLayers(Module):
         self.causal = causal
         self.layers = ModuleList([])

+        self.attn_heads = heads
+        self.attn_dim_head = dim_head
+
         # routing related
         # 1. greater than one residual stream, proposed in Hyper-Connections paper https://arxiv.org/abs/2409.19606
         # 2. integrating more than one past layer, from LIMe paper https://arxiv.org/abs/2502.09245
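
The two new attributes record the head configuration the layer stack was built with, so downstream modules can match its attention geometry; the AttentionPool wiring further down in this diff is the consumer. A minimal sketch of what this exposes, using the Encoder subclass of AttentionLayers:

    from x_transformers import Encoder

    # Encoder subclasses AttentionLayers, so as of this change the
    # instance carries the head configuration it was constructed with
    enc = Encoder(dim = 512, depth = 6, heads = 8)

    print(enc.attn_heads)     # 8
    print(enc.attn_dim_head)  # 64, the library's default head dimension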
@@ -2758,6 +2761,9 @@ class AttentionPool(Module):
         dim_context = None,
         add_residual = False,
         depth = 1,
+        heads = 8,
+        dim_head = 64,
+        use_transformer_blocks = None,
         squeeze_output = None,
         attn_kwargs: dict = dict()
     ):
@@ -2767,13 +2773,18 @@ class AttentionPool(Module):
         squeeze_output = default(squeeze_output, False)
         assert not (squeeze_output and num_pooled_tokens > 1)

+        use_transformer_blocks = default(use_transformer_blocks, depth > 1)
+        assert use_transformer_blocks or depth == 1
+
         self.queries = nn.Parameter(torch.randn(num_pooled_tokens, dim) * 1e-2)

-        if depth > 1:
+        if use_transformer_blocks:
             assert not add_residual, 'residual already in effect when doing a full cross attention based transformer for pooling'
-            self.pooler = CrossAttender(dim = dim, cross_attn_dim_context = dim_context, depth = depth, **attn_kwargs)
+            attn_kwargs = {f'attn_{k}': v for k, v in attn_kwargs.items()}
+
+            self.pooler = CrossAttender(dim = dim, cross_attn_dim_context = dim_context, depth = depth, heads = heads, attn_dim_head = dim_head, **attn_kwargs)
         else:
-            self.pooler = Attention(dim = dim, dim_context = dim_context, **attn_kwargs)
+            self.pooler = Attention(dim = dim, dim_context = dim_context, heads = heads, dim_head = dim_head, **attn_kwargs)

         self.add_residual = add_residual

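AttentionPool now takes heads and dim_head directly, plus a use_transformer_blocks flag that defaults to depth > 1: depth 1 keeps the single cross-attention pooler, while deeper pools (or an explicit opt-in) use a full CrossAttender stack. The attn_ prefixing of attn_kwargs matches the convention AttentionLayers uses to route per-attention keyword arguments to its attention sublayers. A construction sketch, assuming dim is the first positional argument and that the pooler is applied to a context tensor (neither detail is shown in this hunk):

    import torch
    from x_transformers.x_transformers import AttentionPool

    # depth = 1 (the default) keeps the plain single-layer Attention pooler
    pool = AttentionPool(512, heads = 8, dim_head = 64)

    # opt into the CrossAttender path even at depth = 1
    block_pool = AttentionPool(512, depth = 1, use_transformer_blocks = True)

    # assumed usage: the learned queries cross-attend over a context sequence
    context = torch.randn(2, 128, 512)
    pooled = pool(context)  # roughly (2, num_pooled_tokens, 512)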
@@ -2999,7 +3010,7 @@ class TransformerWrapper(Module):
         self.attn_pool = None

         if attn_pool:
-            self.attn_pool = AttentionPool(dim = default(dim_pooled_tokens, dim), dim_context = dim, num_pooled_tokens = num_pooled_tokens, depth = attn_pool_depth)
+            self.attn_pool = AttentionPool(dim = default(dim_pooled_tokens, dim), dim_context = dim, num_pooled_tokens = num_pooled_tokens, depth = attn_pool_depth, heads = self.attn_layers.attn_heads, dim_head = self.attn_layers.attn_dim_head)

         # whether to average pool the embed (`global average pool`)

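This is where the new AttentionLayers attributes pay off: the attention pooler now inherits the trunk's heads and dim_head instead of silently using AttentionPool's own defaults of 8 heads at dimension 64. A sketch of the effect, assuming attn_pool is a TransformerWrapper constructor flag as the hunk suggests:

    from x_transformers import TransformerWrapper, Encoder

    model = TransformerWrapper(
        num_tokens = 20000,
        max_seq_len = 1024,
        attn_pool = True,  # pooler attention now mirrors the 12-head trunk below
        attn_layers = Encoder(dim = 512, depth = 6, heads = 12)
    )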
x_transformers-2.5.2.dist-info/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: x-transformers
-Version: 2.5.0
+Version: 2.5.2
 Summary: X-Transformers
 Project-URL: Homepage, https://pypi.org/project/x-transformers/
 Project-URL: Repository, https://github.com/lucidrains/x-transformers
x_transformers-2.5.2.dist-info/RECORD

@@ -9,10 +9,10 @@ x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg
 x_transformers/neo_mlp.py,sha256=XCNnnop9WLarcxap1kGuYc1x8GHvwkZiDRnXOxSl3Po,3452
 x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
 x_transformers/up_wrapper.py,sha256=YC2LN14_7Xx9Wtiek2rtEJ_qHqdfSmKlh3d7Cgxwd80,7073
-x_transformers/x_transformers.py,sha256=NOTTbqDk5qZEY2MPpdIwJv4BvGGhXt_nIffrgQDXTf4,119346
+x_transformers/x_transformers.py,sha256=vmMrHP3hAQ9iAJlRN1pKmXOn7pD3mfh_ndtaR7LMPzU,119860
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=AwwYUm8yDAtKQyKJDIhYMsiLTJ_skh3scUFMjp5sda8,8597
-x_transformers-2.5.0.dist-info/METADATA,sha256=ZcxK61msWcsm8vEPG-FirnPTnwG6HNKtvy2ZoLluJHM,90223
-x_transformers-2.5.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-x_transformers-2.5.0.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-2.5.0.dist-info/RECORD,,
+x_transformers-2.5.2.dist-info/METADATA,sha256=yeferX_PJIv0Lxs36vZSV7Z2w9ol4udiUAON95hP_bY,90223
+x_transformers-2.5.2.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+x_transformers-2.5.2.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-2.5.2.dist-info/RECORD,,
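
Each RECORD row has the form path,sha256=<digest>,size-in-bytes, where the digest is the file's SHA-256 encoded as URL-safe base64 with the trailing = padding stripped, per the wheel RECORD format. A small sketch for verifying a row against an unpacked wheel:

    import base64, hashlib

    def record_digest(path):
        # RECORD digests: sha256, urlsafe base64, '=' padding stripped
        with open(path, 'rb') as f:
            raw = hashlib.sha256(f.read()).digest()
        return base64.urlsafe_b64encode(raw).rstrip(b'=').decode()

    # inside an unpacked x_transformers-2.5.2 wheel this should print
    # vmMrHP3hAQ9iAJlRN1pKmXOn7pD3mfh_ndtaR7LMPzU
    print(record_digest('x_transformers/x_transformers.py'))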