x-transformers 1.23.1__py3-none-any.whl → 1.23.2__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as published to their public registry, and is provided for informational purposes only.
- x_transformers/x_transformers.py +15 -1
- {x_transformers-1.23.1.dist-info → x_transformers-1.23.2.dist-info}/METADATA +1 -1
- {x_transformers-1.23.1.dist-info → x_transformers-1.23.2.dist-info}/RECORD +6 -6
- {x_transformers-1.23.1.dist-info → x_transformers-1.23.2.dist-info}/LICENSE +0 -0
- {x_transformers-1.23.1.dist-info → x_transformers-1.23.2.dist-info}/WHEEL +0 -0
- {x_transformers-1.23.1.dist-info → x_transformers-1.23.2.dist-info}/top_level.txt +0 -0
x_transformers/x_transformers.py
CHANGED
@@ -650,6 +650,7 @@ class Attention(nn.Module):
         num_mem_kv = 0,
         dropout = 0.,
         on_attn = False,
+        gate_value_heads = False,
         gate_values = False,
         zero_init_output = False,
         max_attend_past = None,
@@ -703,7 +704,14 @@ class Attention(nn.Module):
         if gate_values:
             self.to_v_gate = nn.Linear(dim, out_dim)
             nn.init.constant_(self.to_v_gate.weight, 0)
-            nn.init.constant_(self.to_v_gate.bias,
+            nn.init.constant_(self.to_v_gate.bias, 10)
+
+        # add per head gating of the output values, from 'Attend to nothing' paper
+        self.to_v_head_gate = None
+        if gate_value_heads:
+            self.to_v_head_gate = nn.Linear(dim, heads)
+            nn.init.constant_(self.to_v_head_gate.weight, 0)
+            nn.init.constant_(self.to_v_head_gate.bias, 10)

         # cosine sim attention
         self.qk_norm = qk_norm
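A note on the initialization above: both the existing value gate and the new per-head gate start with zero weights and a bias of 10, and the gate output is passed through a sigmoid downstream (see the forward-pass hunk below). Since sigmoid(10) ≈ 0.99995, the gates begin effectively open and the feature is a near no-op until training moves them. A minimal sketch of that behaviour, assuming standard PyTorch; the sizes are made up for illustration and this is not the library's own code:

    import torch
    from torch import nn

    dim, heads = 512, 8                          # hypothetical sizes

    # same construction as the hunk above: one scalar per attention head,
    # zero weights, bias initialized to 10
    to_v_head_gate = nn.Linear(dim, heads)
    nn.init.constant_(to_v_head_gate.weight, 0)
    nn.init.constant_(to_v_head_gate.bias, 10)

    x = torch.randn(2, 16, dim)                  # (batch, seq, dim) token features
    gate = to_v_head_gate(x).sigmoid()           # (batch, seq, heads)
    print(gate.min().item())                     # ≈ 0.99995: gates start open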
@@ -905,6 +913,12 @@ class Attention(nn.Module):
         if head_scale:
             out = out * self.head_scale_params

+        # per head gating, from https://arxiv.org/abs/2306.12929
+
+        if exists(self.to_v_head_gate):
+            head_gate = self.to_v_head_gate(x)
+            out = out * rearrange(head_gate, 'b n h -> b h n 1').sigmoid()
+
         # merge heads

         out = rearrange(out, 'b h n d -> b n (h d)')
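For reference, a hedged usage sketch of the new option. Assuming the `gate_value_heads` flag is exposed through the library's usual `attn_`-prefixed keyword routing into the attention layers (not verified against this exact release), enabling per-head value gating from the high-level API might look like:

    import torch
    from x_transformers import TransformerWrapper, Decoder

    model = TransformerWrapper(
        num_tokens = 20000,
        max_seq_len = 1024,
        attn_layers = Decoder(
            dim = 512,
            depth = 6,
            heads = 8,
            attn_gate_value_heads = True   # assumed routing of the new gate_value_heads flag
        )
    )

    tokens = torch.randint(0, 20000, (1, 256))
    logits = model(tokens)                 # (1, 256, 20000)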
{x_transformers-1.23.1.dist-info → x_transformers-1.23.2.dist-info}/RECORD
CHANGED
@@ -3,10 +3,10 @@ x_transformers/attend.py,sha256=hZcz_iijzbEqbXp2_BPEVL-1LoHXmYaHE6e6Oy-7hFE,1126
 x_transformers/autoregressive_wrapper.py,sha256=f2u0usjUfAlXwgTz87O8J8XjGTbsbrx2XEP6K2beSNI,8944
 x_transformers/continuous_autoregressive_wrapper.py,sha256=pTiDqu6JRUlnQJQp_xHATYHy0lgSd6ERLqyiFO3pC-4,1575
 x_transformers/nonautoregressive_wrapper.py,sha256=AQLE4rA_Kh8VNoe9OzpwyeWson34sRkhks4dn4seNjI,10414
-x_transformers/x_transformers.py,sha256=
+x_transformers/x_transformers.py,sha256=KQ9mU_jE27whl6yQI67grF0S8Xhd3GndnM6Yd0-q-lw,61162
 x_transformers/xl_autoregressive_wrapper.py,sha256=DCx4n0_c1tFai4nOqaWVnqx2p9eutsZsDMiMP1ckxNU,4117
-x_transformers-1.23.
-x_transformers-1.23.
-x_transformers-1.23.
-x_transformers-1.23.
-x_transformers-1.23.
+x_transformers-1.23.2.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.23.2.dist-info/METADATA,sha256=8h0sbx8-4yNTOJuAZLbe5HQ16hsmZI1M_mT-rMIIMJc,661
+x_transformers-1.23.2.dist-info/WHEEL,sha256=yQN5g4mg4AybRjkgi-9yy4iQEFibGQmlz78Pik5Or-A,92
+x_transformers-1.23.2.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.23.2.dist-info/RECORD,,
{x_transformers-1.23.1.dist-info → x_transformers-1.23.2.dist-info}/LICENSE
File without changes

{x_transformers-1.23.1.dist-info → x_transformers-1.23.2.dist-info}/WHEEL
File without changes

{x_transformers-1.23.1.dist-info → x_transformers-1.23.2.dist-info}/top_level.txt
File without changes