x-transformers 1.32.7__py3-none-any.whl → 1.32.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- x_transformers/continuous.py +13 -2
- x_transformers/x_transformers.py +27 -1
- {x_transformers-1.32.7.dist-info → x_transformers-1.32.8.dist-info}/METADATA +1 -1
- {x_transformers-1.32.7.dist-info → x_transformers-1.32.8.dist-info}/RECORD +7 -7
- {x_transformers-1.32.7.dist-info → x_transformers-1.32.8.dist-info}/LICENSE +0 -0
- {x_transformers-1.32.7.dist-info → x_transformers-1.32.8.dist-info}/WHEEL +0 -0
- {x_transformers-1.32.7.dist-info → x_transformers-1.32.8.dist-info}/top_level.txt +0 -0
x_transformers/continuous.py
CHANGED
@@ -9,6 +9,7 @@ from x_transformers.x_transformers import (
     ScaledSinusoidalEmbedding,
     AbsolutePositionalEmbedding,
     LayerNorm,
+    masked_mean,
     always,
     pad_at_dim
 )
@@ -39,7 +40,8 @@ class ContinuousTransformerWrapper(nn.Module):
         post_emb_norm = False,
         emb_dropout = 0.,
         use_abs_pos_emb = True,
-        scaled_sinu_pos_emb = False
+        scaled_sinu_pos_emb = False,
+        average_pool_embed = False
     ):
         super().__init__()
         dim = attn_layers.dim
@@ -72,6 +74,10 @@ class ContinuousTransformerWrapper(nn.Module):
 
         self.attn_layers = attn_layers
 
+        # average pool
+
+        self.average_pool_embed = average_pool_embed
+
         # project in and out
 
         self.project_in = nn.Linear(dim_in, dim, bias = False) if exists(dim_in) else nn.Identity()
@@ -92,7 +98,7 @@ class ContinuousTransformerWrapper(nn.Module):
         prepend_mask = None,
         **kwargs
     ):
-        batch, seq, device = *x.shape[:2], x.device
+        batch, seq, orig_mask, device = *x.shape[:2], mask, x.device
 
         x = self.project_in(x)
         x = x + self.pos_emb(x, pos = pos)
@@ -136,6 +142,11 @@ class ContinuousTransformerWrapper(nn.Module):
             m, x = unpack(x, mem_ps, 'b * d')
             intermediates.memory_tokens = m
 
+        if self.average_pool_embed:
+            x = masked_mean(x, mask = orig_mask)
+
+        # maybe linear project out
+
         out = self.project_out(x) if not return_embeddings else x
 
         if return_intermediates:
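Taken together, the continuous.py changes add an `average_pool_embed` flag that mean-pools the sequence dimension (skipping masked positions, via the original mask captured as `orig_mask`) right before the output projection. A minimal usage sketch of the new flag, assuming the package's public `ContinuousTransformerWrapper` / `Encoder` API; the tensor sizes are illustrative, not from the source:

import torch
from x_transformers import ContinuousTransformerWrapper, Encoder

# sketch: pool a continuous-input encoder down to one vector per sequence
model = ContinuousTransformerWrapper(
    dim_in = 32,
    dim_out = 64,
    max_seq_len = 256,
    average_pool_embed = True,   # flag introduced in this release
    attn_layers = Encoder(dim = 128, depth = 4, heads = 8)
)

x = torch.randn(2, 100, 32)
mask = torch.ones(2, 100).bool()
mask[:, 50:] = False             # padded positions are excluded from the mean

out = model(x, mask = mask)      # expected, per the diff: (2, 64) rather than (2, 100, 64)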
x_transformers/x_transformers.py
CHANGED
@@ -99,6 +99,17 @@ def l2norm(t, groups = 1):
 def softclamp(t, value):
     return (t / value).tanh() * value
 
+def masked_mean(t, mask = None, dim = 1):
+    if not exists(mask):
+        return t.mean(dim = dim)
+
+    dims_append = (1,) * (t.ndim - mask.ndim)
+    mask = mask.reshape(*mask.shape, *dims_append)
+
+    num = (t * mask).sum(dim = dim)
+    den = mask.sum(dim = dim).clamp(min = 1.)
+    return num / den
+
 def pad_at_dim(t, pad: Tuple[int, int], dim = -1, value = 0.):
     if pad == (0, 0):
         return t
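The new `masked_mean` helper averages over a dimension while ignoring positions where the mask is False, and falls back to a plain mean when no mask is given. A small standalone sketch of the same logic, reimplemented here purely for illustration (the `exists` check is replaced with a plain `is None` test):

import torch

def masked_mean(t, mask = None, dim = 1):
    # no mask: ordinary mean over the chosen dimension
    if mask is None:
        return t.mean(dim = dim)

    # right-pad the mask with singleton dims so it broadcasts against t
    dims_append = (1,) * (t.ndim - mask.ndim)
    mask = mask.reshape(*mask.shape, *dims_append)

    num = (t * mask).sum(dim = dim)
    den = mask.sum(dim = dim).clamp(min = 1.)  # avoid division by zero for fully masked rows
    return num / den

t = torch.tensor([[[1.], [3.], [100.]]])        # (batch = 1, seq = 3, dim = 1)
mask = torch.tensor([[True, True, False]])      # last position is padding
print(masked_mean(t, mask))                     # tensor([[2.]]) -- the 100. is ignored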
@@ -1666,6 +1677,8 @@ class AttentionLayers(Module):
 
         layer_variables = tuple(tuple(layer_variable[i] for i in layers_execute_order) for layer_variable in layer_variables)
 
+        first_skip = None
+
         # go through the attention and feedforward layers
 
         for ind, (layer_type, (norm, block, residual_fn), layer_dropout) in enumerate(zip(*layer_variables)):
@@ -1674,6 +1687,9 @@ class AttentionLayers(Module):
            if self.training and layer_dropout > 0. and random() < layer_dropout:
                continue
 
+           if ind == 1:
+               first_skip = x.clone()
+
            if layer_type == 'a':
                if return_hiddens:
                    hiddens.append(x)
@@ -1909,6 +1925,7 @@ class TransformerWrapper(Module):
         l2norm_embed = False,
         emb_frac_gradient = 1., # GLM-130B and Cogview successfully used this, set at 0.1
         attn_z_loss_weight = 1e-4,
+        average_pool_embed = False
     ):
         super().__init__()
 
@@ -1954,6 +1971,10 @@ class TransformerWrapper(Module):
 
         assert num_output_heads > 0
 
+        # whether to average pool the embed (`global average pool`)
+
+        self.average_pool_embed = average_pool_embed
+
         # output head, usually to logits of num_tokens
 
         logits_dim = default(logits_dim, num_tokens)
@@ -2015,7 +2036,7 @@ class TransformerWrapper(Module):
         cache: LayerIntermediates | None = None,
         **kwargs
     ):
-        b, n, device, num_mems, has_memory_tokens, emb_frac_gradient = x.shape[0], x.shape[1], x.device, self.num_memory_tokens, self.num_memory_tokens > 0, self.emb_frac_gradient
+        b, n, device, num_mems, has_memory_tokens, emb_frac_gradient, orig_mask = x.shape[0], x.shape[1], x.device, self.num_memory_tokens, self.num_memory_tokens > 0, self.emb_frac_gradient, mask
 
         return_hiddens = return_mems | return_attn | return_intermediates | return_attn_z_loss
         return_embeddings = return_embeddings | (not exists(self.to_logits))
@@ -2118,6 +2139,11 @@ class TransformerWrapper(Module):
 
         x = x[:, :n]
 
+        # global average pool
+
+        if self.average_pool_embed:
+            x = masked_mean(x, mask = orig_mask, dim = 1)
+
         # projecting to logits
 
         if not return_embeddings:
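On the token-based `TransformerWrapper`, the new flag works the same way: with `average_pool_embed = True`, the per-token hidden states are mean-pooled over the unmasked positions before the logits projection, which suits sequence-level classification. A hedged usage sketch based on the diff, assuming the standard `TransformerWrapper` / `Encoder` constructor arguments; the choice of `logits_dim = 10` is illustrative only:

import torch
from x_transformers import TransformerWrapper, Encoder

# sketch: a sequence classifier that pools token embeddings with the new flag
model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 512,
    logits_dim = 10,                 # e.g. 10 classes instead of a vocabulary
    average_pool_embed = True,       # flag introduced in 1.32.8
    attn_layers = Encoder(dim = 256, depth = 6, heads = 8)
)

ids  = torch.randint(0, 20000, (3, 512))
mask = torch.ones(3, 512).bool()

logits = model(ids, mask = mask)     # expected, per the diff: (3, 10), one prediction per sequence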
{x_transformers-1.32.7.dist-info → x_transformers-1.32.8.dist-info}/RECORD
CHANGED
@@ -1,15 +1,15 @@
 x_transformers/__init__.py,sha256=-MkQrSc37cTVDX7AOykxunYnqVtFlQ7lb0Cse5dsGWU,793
 x_transformers/attend.py,sha256=MI-m91wumBFqFqr_KK9MLgsLk_vPeaVbFMyDr_mWdmY,11349
 x_transformers/autoregressive_wrapper.py,sha256=uX8Mb0zLsQrZECt_9UGt35g7tC05Rk3nPqO6xp2FFCc,9619
-x_transformers/continuous.py,sha256=
+x_transformers/continuous.py,sha256=cIVEdhfei258__ziV7kQBrJMxCel54bExBTDrO9rfCI,6450
 x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/nonautoregressive_wrapper.py,sha256=ys_p8obc7lTeeodCqvkRKxOXQ1C9T3j5Jwr-JbVgnXk,10432
-x_transformers/x_transformers.py,sha256=
+x_transformers/x_transformers.py,sha256=cq364zjUVvGEeFxdu703yI2tp1VhpxTIpLTgMshHpzI,77392
 x_transformers/xl_autoregressive_wrapper.py,sha256=DCx4n0_c1tFai4nOqaWVnqx2p9eutsZsDMiMP1ckxNU,4117
 x_transformers/xval.py,sha256=QE1ltYZTR_eGgIHPP2BrMWVWVLqMW-OpDZh87BSmQEg,8563
-x_transformers-1.32.
-x_transformers-1.32.
-x_transformers-1.32.
-x_transformers-1.32.
-x_transformers-1.32.
+x_transformers-1.32.8.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.32.8.dist-info/METADATA,sha256=tEHQVjqqXKQ2eSd-j5yrmfubkjuQZga_sT_5XgHanQo,661
+x_transformers-1.32.8.dist-info/WHEEL,sha256=R0nc6qTxuoLk7ShA2_Y-UWkN8ZdfDBG2B6Eqpz2WXbs,91
+x_transformers-1.32.8.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.32.8.dist-info/RECORD,,
{x_transformers-1.32.7.dist-info → x_transformers-1.32.8.dist-info}/LICENSE
File without changes
{x_transformers-1.32.7.dist-info → x_transformers-1.32.8.dist-info}/WHEEL
File without changes
{x_transformers-1.32.7.dist-info → x_transformers-1.32.8.dist-info}/top_level.txt
File without changes