x-transformers 1.32.7__py3-none-any.whl → 1.32.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- x_transformers/continuous.py +13 -2
- x_transformers/x_transformers.py +22 -1
- {x_transformers-1.32.7.dist-info → x_transformers-1.32.9.dist-info}/METADATA +1 -1
- {x_transformers-1.32.7.dist-info → x_transformers-1.32.9.dist-info}/RECORD +7 -7
- {x_transformers-1.32.7.dist-info → x_transformers-1.32.9.dist-info}/LICENSE +0 -0
- {x_transformers-1.32.7.dist-info → x_transformers-1.32.9.dist-info}/WHEEL +0 -0
- {x_transformers-1.32.7.dist-info → x_transformers-1.32.9.dist-info}/top_level.txt +0 -0
x_transformers/continuous.py
CHANGED
```diff
@@ -9,6 +9,7 @@ from x_transformers.x_transformers import (
     ScaledSinusoidalEmbedding,
     AbsolutePositionalEmbedding,
     LayerNorm,
+    masked_mean,
     always,
     pad_at_dim
 )
```
```diff
@@ -39,7 +40,8 @@ class ContinuousTransformerWrapper(nn.Module):
         post_emb_norm = False,
         emb_dropout = 0.,
         use_abs_pos_emb = True,
-        scaled_sinu_pos_emb = False
+        scaled_sinu_pos_emb = False,
+        average_pool_embed = False
     ):
         super().__init__()
         dim = attn_layers.dim
```
```diff
@@ -72,6 +74,10 @@ class ContinuousTransformerWrapper(nn.Module):
 
         self.attn_layers = attn_layers
 
+        # average pool
+
+        self.average_pool_embed = average_pool_embed
+
         # project in and out
 
         self.project_in = nn.Linear(dim_in, dim, bias = False) if exists(dim_in) else nn.Identity()
```
```diff
@@ -92,7 +98,7 @@ class ContinuousTransformerWrapper(nn.Module):
         prepend_mask = None,
         **kwargs
     ):
-        batch, seq, device = *x.shape[:2], x.device
+        batch, seq, orig_mask, device = *x.shape[:2], mask, x.device
 
         x = self.project_in(x)
         x = x + self.pos_emb(x, pos = pos)
```
```diff
@@ -136,6 +142,11 @@ class ContinuousTransformerWrapper(nn.Module):
             m, x = unpack(x, mem_ps, 'b * d')
             intermediates.memory_tokens = m
 
+        if self.average_pool_embed:
+            x = masked_mean(x, mask = orig_mask)
+
+        # maybe linear project out
+
         out = self.project_out(x) if not return_embeddings else x
 
         if return_intermediates:
```
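Taken together, these hunks let `ContinuousTransformerWrapper` return one mean-pooled vector per sequence, computed only over unmasked positions, before `project_out` is applied. A minimal usage sketch under that assumption; the dimensions and depth below are arbitrary illustration values, not package defaults:

```python
import torch
from x_transformers import ContinuousTransformerWrapper, Encoder

# hypothetical configuration, chosen only to illustrate the new flag
model = ContinuousTransformerWrapper(
    dim_in = 32,
    dim_out = 8,
    max_seq_len = 128,
    average_pool_embed = True,  # added in 1.32.9: masked mean over the sequence before project_out
    attn_layers = Encoder(
        dim = 64,
        depth = 2,
        heads = 4
    )
)

x = torch.randn(2, 100, 32)
mask = torch.ones(2, 100).bool()
mask[1, 50:] = False            # second sequence is padded past position 50

out = model(x, mask = mask)
print(out.shape)                # expected: torch.Size([2, 8]), one pooled vector per sequence
```

Without the flag, the same call would return per-position outputs of shape `(2, 100, 8)`.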
x_transformers/x_transformers.py
CHANGED
```diff
@@ -99,6 +99,17 @@ def l2norm(t, groups = 1):
 def softclamp(t, value):
     return (t / value).tanh() * value
 
+def masked_mean(t, mask = None, dim = 1):
+    if not exists(mask):
+        return t.mean(dim = dim)
+
+    dims_append = (1,) * (t.ndim - mask.ndim)
+    mask = mask.reshape(*mask.shape, *dims_append)
+
+    num = (t * mask).sum(dim = dim)
+    den = mask.sum(dim = dim).clamp(min = 1.)
+    return num / den
+
 def pad_at_dim(t, pad: Tuple[int, int], dim = -1, value = 0.):
     if pad == (0, 0):
         return t
```
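The new `masked_mean` helper averages over a dimension while ignoring masked-out positions, and falls back to a plain mean when no mask is passed; the denominator is clamped so fully-masked rows do not divide by zero. A small sketch of its behaviour, assuming x-transformers ≥ 1.32.9 is installed (shapes and values are illustrative only):

```python
import torch
from x_transformers.x_transformers import masked_mean

# batch of 2 sequences, length 4, feature dim 3; the second sequence has 2 padded positions
t = torch.randn(2, 4, 3)
mask = torch.tensor([
    [True, True, True, True],
    [True, True, False, False],
])

pooled = masked_mean(t, mask = mask, dim = 1)  # shape (2, 3)

# row 0 matches the plain mean, row 1 averages only the two valid positions
assert torch.allclose(pooled[0], t[0].mean(dim = 0), atol = 1e-5)
assert torch.allclose(pooled[1], t[1, :2].mean(dim = 0), atol = 1e-5)

# with no mask the helper reduces to an ordinary mean over dim 1
assert torch.allclose(masked_mean(t), t.mean(dim = 1), atol = 1e-5)
```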
```diff
@@ -1909,6 +1920,7 @@ class TransformerWrapper(Module):
         l2norm_embed = False,
         emb_frac_gradient = 1., # GLM-130B and Cogview successfully used this, set at 0.1
         attn_z_loss_weight = 1e-4,
+        average_pool_embed = False
     ):
         super().__init__()
 
```
```diff
@@ -1954,6 +1966,10 @@ class TransformerWrapper(Module):
 
         assert num_output_heads > 0
 
+        # whether to average pool the embed (`global average pool`)
+
+        self.average_pool_embed = average_pool_embed
+
         # output head, usually to logits of num_tokens
 
         logits_dim = default(logits_dim, num_tokens)
```
```diff
@@ -2015,7 +2031,7 @@ class TransformerWrapper(Module):
         cache: LayerIntermediates | None = None,
         **kwargs
     ):
-        b, n, device, num_mems, has_memory_tokens, emb_frac_gradient = x.shape[0], x.shape[1], x.device, self.num_memory_tokens, self.num_memory_tokens > 0, self.emb_frac_gradient
+        b, n, device, num_mems, has_memory_tokens, emb_frac_gradient, orig_mask = x.shape[0], x.shape[1], x.device, self.num_memory_tokens, self.num_memory_tokens > 0, self.emb_frac_gradient, mask
 
         return_hiddens = return_mems | return_attn | return_intermediates | return_attn_z_loss
         return_embeddings = return_embeddings | (not exists(self.to_logits))
```
```diff
@@ -2118,6 +2134,11 @@ class TransformerWrapper(Module):
 
         x = x[:, :n]
 
+        # global average pool
+
+        if self.average_pool_embed:
+            x = masked_mean(x, mask = orig_mask, dim = 1)
+
         # projecting to logits
 
         if not return_embeddings:
```
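In `TransformerWrapper` the flag pools the final hidden states (after memory tokens are stripped and the sequence is trimmed back to length `n`) before the logits projection, which makes the output head usable as a per-sequence classifier. A hedged sketch of one way this could be used; `logits_dim = 5` and the other hyperparameters are arbitrary, not suggested settings:

```python
import torch
from x_transformers import TransformerWrapper, Encoder

# hypothetical sequence-classification setup: pooled embedding fed to a 5-way output head
model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 512,
    logits_dim = 5,              # repurpose the output head as a small classifier
    average_pool_embed = True,   # added in 1.32.9: masked mean over positions before to_logits
    attn_layers = Encoder(
        dim = 128,
        depth = 4,
        heads = 8
    )
)

ids = torch.randint(0, 20000, (2, 256))
mask = torch.ones(2, 256).bool()
mask[1, 200:] = False            # padded positions in the second sequence are excluded from the pool

logits = model(ids, mask = mask)
print(logits.shape)              # expected: torch.Size([2, 5])
```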
{x_transformers-1.32.7.dist-info → x_transformers-1.32.9.dist-info}/RECORD
CHANGED
```diff
@@ -1,15 +1,15 @@
 x_transformers/__init__.py,sha256=-MkQrSc37cTVDX7AOykxunYnqVtFlQ7lb0Cse5dsGWU,793
 x_transformers/attend.py,sha256=MI-m91wumBFqFqr_KK9MLgsLk_vPeaVbFMyDr_mWdmY,11349
 x_transformers/autoregressive_wrapper.py,sha256=uX8Mb0zLsQrZECt_9UGt35g7tC05Rk3nPqO6xp2FFCc,9619
-x_transformers/continuous.py,sha256=
+x_transformers/continuous.py,sha256=cIVEdhfei258__ziV7kQBrJMxCel54bExBTDrO9rfCI,6450
 x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/nonautoregressive_wrapper.py,sha256=ys_p8obc7lTeeodCqvkRKxOXQ1C9T3j5Jwr-JbVgnXk,10432
-x_transformers/x_transformers.py,sha256=
+x_transformers/x_transformers.py,sha256=8558TPHcDxWUvJYz01EdeyZl0lkHB14bzlsEMwSMPyw,77300
 x_transformers/xl_autoregressive_wrapper.py,sha256=DCx4n0_c1tFai4nOqaWVnqx2p9eutsZsDMiMP1ckxNU,4117
 x_transformers/xval.py,sha256=QE1ltYZTR_eGgIHPP2BrMWVWVLqMW-OpDZh87BSmQEg,8563
-x_transformers-1.32.
-x_transformers-1.32.
-x_transformers-1.32.
-x_transformers-1.32.
-x_transformers-1.32.
+x_transformers-1.32.9.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.32.9.dist-info/METADATA,sha256=-GidCdPhcKpZ49ElbeuJUPko5LZZP_vyEodaN_P3g48,661
+x_transformers-1.32.9.dist-info/WHEEL,sha256=R0nc6qTxuoLk7ShA2_Y-UWkN8ZdfDBG2B6Eqpz2WXbs,91
+x_transformers-1.32.9.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.32.9.dist-info/RECORD,,
```
{x_transformers-1.32.7.dist-info → x_transformers-1.32.9.dist-info}/LICENSE: file without changes
{x_transformers-1.32.7.dist-info → x_transformers-1.32.9.dist-info}/WHEEL: file without changes
{x_transformers-1.32.7.dist-info → x_transformers-1.32.9.dist-info}/top_level.txt: file without changes