x-transformers 1.23.1__py3-none-any.whl → 1.23.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- x_transformers/x_transformers.py +43 -2
- {x_transformers-1.23.1.dist-info → x_transformers-1.23.3.dist-info}/METADATA +1 -1
- {x_transformers-1.23.1.dist-info → x_transformers-1.23.3.dist-info}/RECORD +6 -6
- {x_transformers-1.23.1.dist-info → x_transformers-1.23.3.dist-info}/LICENSE +0 -0
- {x_transformers-1.23.1.dist-info → x_transformers-1.23.3.dist-info}/WHEEL +0 -0
- {x_transformers-1.23.1.dist-info → x_transformers-1.23.3.dist-info}/top_level.txt +0 -0
x_transformers/x_transformers.py
CHANGED
```diff
@@ -650,6 +650,7 @@ class Attention(nn.Module):
         num_mem_kv = 0,
         dropout = 0.,
         on_attn = False,
+        gate_value_heads = False,
         gate_values = False,
         zero_init_output = False,
         max_attend_past = None,
```
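The new `gate_value_heads` flag is surfaced as a constructor argument on `Attention`. A minimal usage sketch, assuming the library's usual routing of `attn_`-prefixed keyword arguments from `Decoder`/`Encoder` down to `Attention` (that routing itself is not part of this diff):

```python
import torch
from x_transformers import TransformerWrapper, Decoder

# sketch: turn on the new per-head value gating via the attn_-prefixed kwarg
model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        attn_gate_value_heads = True  # new flag introduced in this release
    )
)

tokens = torch.randint(0, 20000, (1, 1024))
logits = model(tokens)  # (1, 1024, 20000)
```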
```diff
@@ -703,7 +704,14 @@ class Attention(nn.Module):
         if gate_values:
             self.to_v_gate = nn.Linear(dim, out_dim)
             nn.init.constant_(self.to_v_gate.weight, 0)
-            nn.init.constant_(self.to_v_gate.bias,
+            nn.init.constant_(self.to_v_gate.bias, 10)
+
+        # add per head gating of the output values, from 'Attend to nothing' paper
+        self.to_v_head_gate = None
+        if gate_value_heads:
+            self.to_v_head_gate = nn.Linear(dim, heads)
+            nn.init.constant_(self.to_v_head_gate.weight, 0)
+            nn.init.constant_(self.to_v_head_gate.bias, 10)
 
         # cosine sim attention
         self.qk_norm = qk_norm
```
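The gates are initialized with zero weights and a bias of 10, so they start out effectively open: with zero weights the pre-sigmoid gate equals the bias, and sigmoid(10) ≈ 0.99995, leaving the values essentially untouched at the start of training. A quick standalone check (not code from the package):

```python
import torch
import torch.nn as nn

dim, heads = 512, 8

# mirror the initialization in the hunk above
to_v_head_gate = nn.Linear(dim, heads)
nn.init.constant_(to_v_head_gate.weight, 0)
nn.init.constant_(to_v_head_gate.bias, 10)

x = torch.randn(2, 16, dim)
gate = to_v_head_gate(x).sigmoid()
print(gate.min().item())  # ~0.99995 -> the gates start (almost) fully open
```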
```diff
@@ -905,6 +913,12 @@ class Attention(nn.Module):
         if head_scale:
             out = out * self.head_scale_params
 
+        # per head gating, from https://arxiv.org/abs/2306.12929
+
+        if exists(self.to_v_head_gate):
+            head_gate = self.to_v_head_gate(x)
+            out = out * rearrange(head_gate, 'b n h -> b h n 1').sigmoid()
+
         # merge heads
 
         out = rearrange(out, 'b h n d -> b n (h d)')
```
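This is the mechanism from the cited paper (https://arxiv.org/abs/2306.12929): each head gets a learned scalar gate per position, computed from the layer input, so a head can scale its value output toward zero and effectively "attend to nothing". A self-contained sketch of just this gating step, with assumed shapes matching the hunk above:

```python
import torch
import torch.nn as nn
from einops import rearrange

b, h, n, d_head, dim = 2, 8, 16, 64, 512

out = torch.randn(b, h, n, d_head)  # per-head attention output
x = torch.randn(b, n, dim)          # layer input the gate is computed from

to_v_head_gate = nn.Linear(dim, h)

# one gate per (position, head), broadcast over the head feature dimension
head_gate = to_v_head_gate(x)                                   # (b, n, h)
out = out * rearrange(head_gate, 'b n h -> b h n 1').sigmoid()  # (b, h, n, d_head)

print(out.shape)  # torch.Size([2, 8, 16, 64])
```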
```diff
@@ -1608,6 +1622,7 @@ class ContinuousTransformerWrapper(nn.Module):
         dim_out = None,
         emb_dim = None,
         max_mem_len = 0,
+        num_memory_tokens = None,
         post_emb_norm = False,
         emb_dropout = 0.,
         use_abs_pos_emb = True,
```
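This brings the memory-token option that `TransformerWrapper` already exposes to the continuous wrapper as well. A hedged usage sketch; the surrounding constructor arguments follow the library's existing continuous example, and the sizes are made up:

```python
import torch
from x_transformers import ContinuousTransformerWrapper, Encoder

model = ContinuousTransformerWrapper(
    dim_in = 32,
    dim_out = 100,
    max_seq_len = 1024,
    num_memory_tokens = 4,  # new in this release
    attn_layers = Encoder(
        dim = 512,
        depth = 6,
        heads = 8
    )
)

x = torch.randn(1, 1024, 32)
out = model(x)  # (1, 1024, 100)
```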
```diff
@@ -1632,10 +1647,21 @@ class ContinuousTransformerWrapper(nn.Module):
         self.post_emb_norm = nn.LayerNorm(dim) if post_emb_norm else nn.Identity()
         self.emb_dropout = nn.Dropout(emb_dropout)
 
-        self.project_in = nn.Linear(dim_in, dim) if exists(dim_in) else nn.Identity()
+        # memory tokens
+
+        num_memory_tokens = default(num_memory_tokens, 0)
+        self.has_memory_tokens = num_memory_tokens > 0
+
+        if num_memory_tokens > 0:
+            self.memory_tokens = nn.Parameter(torch.randn(num_memory_tokens, dim))
+
+        # attention layers
 
         self.attn_layers = attn_layers
 
+        # project in and out
+
+        self.project_in = nn.Linear(dim_in, dim) if exists(dim_in) else nn.Identity()
         self.project_out = nn.Linear(dim, dim_out) if exists(dim_out) else nn.Identity()
 
     def forward(
```
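The memory tokens are a single learned `(num_memory_tokens, dim)` parameter shared across the batch; in the forward pass below they are broadcast to every sequence with `einops.repeat`. A small shape sketch with assumed sizes:

```python
import torch
from einops import repeat

num_memory_tokens, dim, batch = 4, 512, 2

memory_tokens = torch.nn.Parameter(torch.randn(num_memory_tokens, dim))

# broadcast the shared memory tokens to every sequence in the batch
m = repeat(memory_tokens, 'm d -> b m d', b = batch)
print(m.shape)  # torch.Size([2, 4, 512])
```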
```diff
@@ -1651,11 +1677,19 @@ class ContinuousTransformerWrapper(nn.Module):
         prepend_embeds = None,
         **kwargs
     ):
+        batch = x.shape[0]
+
         x = self.project_in(x)
         x = x + self.pos_emb(x, pos = pos)
 
         x = self.post_emb_norm(x)
 
+        # memory tokens
+
+        if self.has_memory_tokens:
+            m = repeat(self.memory_tokens, 'm d -> b m d', b = batch)
+            x, mem_ps = pack([m, x], 'b * d')
+
         # whether to append embeds, as in PaLI, for image embeddings
 
         if exists(prepend_embeds):
```
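`einops.pack` concatenates the broadcast memory tokens in front of the input along the sequence axis and records the packed shapes (`mem_ps`) so the two pieces can be separated again after the attention layers. A standalone illustration with assumed sizes:

```python
import torch
from einops import pack

b, n, num_mem, dim = 2, 16, 4, 512

m = torch.randn(b, num_mem, dim)  # memory tokens, already broadcast to the batch
x = torch.randn(b, n, dim)        # the actual continuous sequence

x_packed, mem_ps = pack([m, x], 'b * d')
print(x_packed.shape)  # torch.Size([2, 20, 512]) - memory tokens prepended
```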
```diff
@@ -1666,8 +1700,15 @@ class ContinuousTransformerWrapper(nn.Module):
 
         x = self.emb_dropout(x)
 
+        # attention layers
+
         x, intermediates = self.attn_layers(x, mask = mask, mems = mems, return_hiddens = True, **kwargs)
 
+        # splice out memory tokens
+
+        if self.has_memory_tokens:
+            m, x = unpack(x, mem_ps, 'b * d')
+
         out = self.project_out(x) if not return_embeddings else x
 
         if return_intermediates:
```
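Because the memory tokens are unpacked again before `project_out`, the wrapper's output length still matches its input length even when `num_memory_tokens > 0`. A hypothetical shape check, reusing the constructor sketch from above with smaller sizes:

```python
import torch
from x_transformers import ContinuousTransformerWrapper, Encoder

model = ContinuousTransformerWrapper(
    dim_in = 32,
    dim_out = 100,
    max_seq_len = 256,
    num_memory_tokens = 4,
    attn_layers = Encoder(dim = 128, depth = 2, heads = 4)
)

x = torch.randn(1, 256, 32)
out = model(x)
assert out.shape == (1, 256, 100)  # memory tokens never show up in the output
```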
{x_transformers-1.23.1.dist-info → x_transformers-1.23.3.dist-info}/RECORD
CHANGED
```diff
@@ -3,10 +3,10 @@ x_transformers/attend.py,sha256=hZcz_iijzbEqbXp2_BPEVL-1LoHXmYaHE6e6Oy-7hFE,1126
 x_transformers/autoregressive_wrapper.py,sha256=f2u0usjUfAlXwgTz87O8J8XjGTbsbrx2XEP6K2beSNI,8944
 x_transformers/continuous_autoregressive_wrapper.py,sha256=pTiDqu6JRUlnQJQp_xHATYHy0lgSd6ERLqyiFO3pC-4,1575
 x_transformers/nonautoregressive_wrapper.py,sha256=AQLE4rA_Kh8VNoe9OzpwyeWson34sRkhks4dn4seNjI,10414
-x_transformers/x_transformers.py,sha256=
+x_transformers/x_transformers.py,sha256=goudsIa79mfyJtzuI0GqTSdGQ5CXG1ga5Is9h3UBC5Y,61861
 x_transformers/xl_autoregressive_wrapper.py,sha256=DCx4n0_c1tFai4nOqaWVnqx2p9eutsZsDMiMP1ckxNU,4117
-x_transformers-1.23.
-x_transformers-1.23.
-x_transformers-1.23.
-x_transformers-1.23.
-x_transformers-1.23.
+x_transformers-1.23.3.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.23.3.dist-info/METADATA,sha256=SXNDjqYSGkklnbXVRg8S52VxDR6VVO62KvRH60abY_k,661
+x_transformers-1.23.3.dist-info/WHEEL,sha256=yQN5g4mg4AybRjkgi-9yy4iQEFibGQmlz78Pik5Or-A,92
+x_transformers-1.23.3.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.23.3.dist-info/RECORD,,
```
{x_transformers-1.23.1.dist-info → x_transformers-1.23.3.dist-info}/LICENSE
File without changes

{x_transformers-1.23.1.dist-info → x_transformers-1.23.3.dist-info}/WHEEL
File without changes

{x_transformers-1.23.1.dist-info → x_transformers-1.23.3.dist-info}/top_level.txt
File without changes