x-transformers 1.23.1__py3-none-any.whl → 1.23.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
x_transformers/x_transformers.py

@@ -650,6 +650,7 @@ class Attention(nn.Module):
         num_mem_kv = 0,
         dropout = 0.,
         on_attn = False,
+        gate_value_heads = False,
         gate_values = False,
         zero_init_output = False,
         max_attend_past = None,
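
The new gate_value_heads flag sits alongside the existing gate_values option. A minimal usage sketch, assuming the library's usual convention of routing attn_-prefixed keyword arguments through Decoder into Attention:

    import torch
    from x_transformers import TransformerWrapper, Decoder

    model = TransformerWrapper(
        num_tokens = 20000,
        max_seq_len = 1024,
        attn_layers = Decoder(
            dim = 512,
            depth = 6,
            heads = 8,
            attn_gate_value_heads = True  # new flag in 1.23.2
        )
    )

    x = torch.randint(0, 20000, (1, 1024))
    logits = model(x)  # (1, 1024, 20000)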
@@ -703,7 +704,14 @@ class Attention(nn.Module):
         if gate_values:
             self.to_v_gate = nn.Linear(dim, out_dim)
             nn.init.constant_(self.to_v_gate.weight, 0)
-            nn.init.constant_(self.to_v_gate.bias, 1)
+            nn.init.constant_(self.to_v_gate.bias, 10)
+
+        # add per head gating of the output values, from 'Attend to nothing' paper
+        self.to_v_head_gate = None
+        if gate_value_heads:
+            self.to_v_head_gate = nn.Linear(dim, heads)
+            nn.init.constant_(self.to_v_head_gate.weight, 0)
+            nn.init.constant_(self.to_v_head_gate.bias, 10)

         # cosine sim attention
         self.qk_norm = qk_norm
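
Note the bias initialization: the existing to_v_gate bias moves from 1 to 10, and the new per-head gate matches it. With zeroed weights, every sigmoid gate then starts essentially fully open, so gating begins as a near-identity and training has to actively learn to close a gate. A quick check:

    import torch

    # with weight = 0 and bias = 10, every gate starts at sigmoid(10)
    print(torch.sigmoid(torch.tensor(10.0)))  # ~0.99995: gates start fully open
    print(torch.sigmoid(torch.tensor(1.0)))   # ~0.7311: the previous, softer default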
@@ -905,6 +913,12 @@ class Attention(nn.Module):
         if head_scale:
             out = out * self.head_scale_params

+        # per head gating, from https://arxiv.org/abs/2306.12929
+
+        if exists(self.to_v_head_gate):
+            head_gate = self.to_v_head_gate(x)
+            out = out * rearrange(head_gate, 'b n h -> b h n 1').sigmoid()
+
         # merge heads

         out = rearrange(out, 'b h n d -> b n (h d)')
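
For readers following the shapes, here is a self-contained sketch of the per-head gating above. The tensor shapes are assumptions inferred from the surrounding code: x is (batch, seq, dim) and out is (batch, heads, seq, head_dim) before heads are merged.

    import torch
    from torch import nn
    from einops import rearrange

    batch, seq, dim, heads, head_dim = 2, 16, 512, 8, 64

    x = torch.randn(batch, seq, dim)                 # layer input
    out = torch.randn(batch, heads, seq, head_dim)   # pre-merge attention output

    to_v_head_gate = nn.Linear(dim, heads)
    nn.init.constant_(to_v_head_gate.weight, 0)
    nn.init.constant_(to_v_head_gate.bias, 10)

    head_gate = to_v_head_gate(x)                    # (batch, seq, heads)
    gate = rearrange(head_gate, 'b n h -> b h n 1')  # broadcast over head_dim
    out = out * gate.sigmoid()                       # ~identity at init
    out = rearrange(out, 'b h n d -> b n (h d)')     # merge heads, as in the diff

At initialization the gate is ~1 everywhere, so the layer behaves as before; training can then learn to shut individual heads off, which is the mechanism proposed in the cited paper (arXiv 2306.12929, on helping attention heads "do nothing").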
x_transformers-1.23.2.dist-info/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: x-transformers
-Version: 1.23.1
+Version: 1.23.2
 Summary: X-Transformers - Pytorch
 Home-page: https://github.com/lucidrains/x-transformers
 Author: Phil Wang
x_transformers-1.23.2.dist-info/RECORD

@@ -3,10 +3,10 @@ x_transformers/attend.py,sha256=hZcz_iijzbEqbXp2_BPEVL-1LoHXmYaHE6e6Oy-7hFE,1126
 x_transformers/autoregressive_wrapper.py,sha256=f2u0usjUfAlXwgTz87O8J8XjGTbsbrx2XEP6K2beSNI,8944
 x_transformers/continuous_autoregressive_wrapper.py,sha256=pTiDqu6JRUlnQJQp_xHATYHy0lgSd6ERLqyiFO3pC-4,1575
 x_transformers/nonautoregressive_wrapper.py,sha256=AQLE4rA_Kh8VNoe9OzpwyeWson34sRkhks4dn4seNjI,10414
-x_transformers/x_transformers.py,sha256=ikf99q_g1_v_wObZed972s2hHrbiDpAq_qGJDmNcVZc,60573
+x_transformers/x_transformers.py,sha256=KQ9mU_jE27whl6yQI67grF0S8Xhd3GndnM6Yd0-q-lw,61162
 x_transformers/xl_autoregressive_wrapper.py,sha256=DCx4n0_c1tFai4nOqaWVnqx2p9eutsZsDMiMP1ckxNU,4117
-x_transformers-1.23.1.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-1.23.1.dist-info/METADATA,sha256=91Gu0qU9ztioZ2_oVeOrLeVkP--n6ngneDxEOMUHJe8,661
-x_transformers-1.23.1.dist-info/WHEEL,sha256=yQN5g4mg4AybRjkgi-9yy4iQEFibGQmlz78Pik5Or-A,92
-x_transformers-1.23.1.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
-x_transformers-1.23.1.dist-info/RECORD,,
+x_transformers-1.23.2.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.23.2.dist-info/METADATA,sha256=8h0sbx8-4yNTOJuAZLbe5HQ16hsmZI1M_mT-rMIIMJc,661
+x_transformers-1.23.2.dist-info/WHEEL,sha256=yQN5g4mg4AybRjkgi-9yy4iQEFibGQmlz78Pik5Or-A,92
+x_transformers-1.23.2.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.23.2.dist-info/RECORD,,