x-transformers 1.18.0__py3-none-any.whl → 1.18.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- x_transformers/attend.py +6 -4
- x_transformers/x_transformers.py +51 -32
- {x_transformers-1.18.0.dist-info → x_transformers-1.18.2.dist-info}/METADATA +1 -1
- {x_transformers-1.18.0.dist-info → x_transformers-1.18.2.dist-info}/RECORD +7 -7
- {x_transformers-1.18.0.dist-info → x_transformers-1.18.2.dist-info}/WHEEL +1 -1
- {x_transformers-1.18.0.dist-info → x_transformers-1.18.2.dist-info}/LICENSE +0 -0
- {x_transformers-1.18.0.dist-info → x_transformers-1.18.2.dist-info}/top_level.txt +0 -0
x_transformers/attend.py
CHANGED
@@ -1,4 +1,5 @@
 from functools import partial
+from typing import Optional
 
 import torch
 from torch import nn, einsum, Tensor
@@ -17,9 +18,9 @@ EfficientAttentionConfig = namedtuple('EfficientAttentionConfig', ['enable_flash
 
 @dataclass
 class Intermediates:
-    qk_similarities: Tensor = None
-    pre_softmax_attn: Tensor = None
-    post_softmax_attn: Tensor = None
+    qk_similarities: Optional[Tensor] = None
+    pre_softmax_attn: Optional[Tensor] = None
+    post_softmax_attn: Optional[Tensor] = None
 
     def to_tuple(self):
         return (self.qk_similarities, self.pre_softmax_attn, self.post_softmax_attn)
@@ -256,7 +257,6 @@ class Attend(nn.Module):
             dots = dots + attn_bias
 
         i, j, dtype = *dots.shape[-2:], dots.dtype
-        pre_softmax_attn = dots.clone()
 
         mask_value = -torch.finfo(dots.dtype).max
 
@@ -272,6 +272,8 @@ class Attend(nn.Module):
             causal_mask = self.create_causal_mask(i, j, device = device)
             dots = dots.masked_fill(causal_mask, mask_value)
 
+        pre_softmax_attn = dots.clone()
+
         attn = self.attn_fn(dots, dim = -1)
         attn = attn.type(dtype)
 
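The behavioural change in attend.py is that `pre_softmax_attn` is now cloned after the causal mask has been applied, so whatever consumes it sees the masked logits. A minimal sketch (toy shapes, not from the package) of why that matters for logsumexp-based statistics such as the z-loss introduced in x_transformers.py below:

```python
import torch

# Toy attention logits: batch 1, 2 heads, sequence length 4 (hypothetical shapes)
dots = torch.randn(1, 2, 4, 4)
mask_value = -torch.finfo(dots.dtype).max

# Causal mask: True above the diagonal, as produced for square logits
causal_mask = torch.ones(4, 4).triu(1).bool()
masked_dots = dots.masked_fill(causal_mask, mask_value)

# Masked positions contribute ~0 to the logsumexp, so cloning after masking
# means the captured pre_softmax_attn reflects what the softmax actually sees
print(dots.logsumexp(dim = -1))
print(masked_dots.logsumexp(dim = -1))
```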
x_transformers/x_transformers.py
CHANGED
@@ -9,7 +9,7 @@ from functools import partial, wraps
 from inspect import isfunction
 from collections import namedtuple
 from dataclasses import dataclass
-from typing import List, Callable
+from typing import List, Callable, Optional
 
 from einops import rearrange, repeat, reduce
 from einops.layers.torch import Rearrange
@@ -23,8 +23,10 @@ DEFAULT_DIM_HEAD = 64
 
 @dataclass
 class LayerIntermediates:
-    hiddens: List[Tensor] = None
-    attn_intermediates: List[Intermediates] = None
+    hiddens: Optional[List[Tensor]] = None
+    attn_intermediates: Optional[List[Intermediates]] = None
+    layer_hiddens: Optional[List[Tensor]] = None
+    attn_z_loss: Optional[Tensor] = None
 
 # helpers
 
@@ -89,6 +91,31 @@ def or_reduce(masks):
         head = head | rest
     return head
 
+# auxiliary loss helpers
+
+def calc_z_loss(
+    pre_softmax_attns: List[Tensor],
+    mask = None,
+    weight = 1.
+):
+    # the same loss applied to the mixture of experts router logits in https://arxiv.org/abs/2202.08906
+    # in the paper, in a tiny footnote, they mention using it on attention logits with stabilizing effects
+    # also used in PaLM as one of the measures
+
+    lse = 0.
+
+    for attn in pre_softmax_attns:
+        lse = lse + attn.logsumexp(dim = -1)
+
+    loss = torch.square(lse)
+    loss = reduce(loss, 'b h n -> b n', 'sum')
+
+    if not exists(mask):
+        return loss.mean() * weight
+
+    loss = loss[mask].sum() / mask.sum().clamp(min = 1e-5)
+    return loss * weight
+
 # init helpers
 
 def init_zero_(layer):
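A hedged, standalone sketch of calling the new `calc_z_loss` helper directly, assuming it is imported from `x_transformers.x_transformers`; the tensor shapes follow the `b h n` pattern used in the `reduce` above, and the mask is an optional boolean padding mask per batch position:

```python
import torch
from x_transformers.x_transformers import calc_z_loss

# Hypothetical pre-softmax attention logits from 6 layers: (batch, heads, n, n)
pre_softmax_attns = [torch.randn(2, 8, 16, 16) for _ in range(6)]

# Optional boolean mask of shape (batch, n); True marks positions that count
mask = torch.ones(2, 16, dtype = torch.bool)

z_loss = calc_z_loss(pre_softmax_attns, mask = mask, weight = 1e-4)
print(z_loss)  # scalar auxiliary loss to add to the main objective
```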
@@ -373,29 +400,6 @@ class AlibiPositionalBias(nn.Module):
 
         return self.bias
 
-class LearnedAlibiPositionalBias(AlibiPositionalBias):
-    def __init__(self, heads, total_heads):
-        super().__init__(heads, total_heads)
-        log_slopes = torch.log(self.slopes)
-        self.learned_logslopes = nn.Parameter(log_slopes)
-
-    def forward(self, i, j):
-        h, device = self.heads, self.device
-
-        def get_slopes(param):
-            return pad_at_dim(param.exp(), (0, h - param.shape[0]), dim = -2)
-
-        if exists(self.bias) and self.bias.shape[-1] >= j and self.bias.shape[-2] >= i:
-            bias = self.bias[..., :i, :j]
-        else:
-            bias = self.get_bias(i, j, device)
-            self.register_buffer('bias', bias, persistent = False)
-
-        slopes = get_slopes(self.learned_logslopes)
-        bias = bias * slopes
-
-        return bias
-
 class RotaryEmbedding(nn.Module):
     def __init__(
         self,
@@ -907,7 +911,6 @@ class AttentionLayers(nn.Module):
         use_simple_rmsnorm = False,
         alibi_pos_bias = False,
         alibi_num_heads = None,
-        alibi_learned = False,
         rel_pos_bias = False,
         rel_pos_num_buckets = 32,
         rel_pos_max_distance = 128,
@@ -979,8 +982,7 @@
         elif alibi_pos_bias:
             alibi_num_heads = default(alibi_num_heads, heads)
             assert alibi_num_heads <= heads, 'number of ALiBi heads must be less than the total number of heads'
-            alibi_pos_klass = LearnedAlibiPositionalBias if alibi_learned else AlibiPositionalBias
-            self.rel_pos = alibi_pos_klass(heads = alibi_num_heads, total_heads = heads)
+            self.rel_pos = AlibiPositionalBias(heads = alibi_num_heads, total_heads = heads)
 
         # determine deepnorm and residual scale
 
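With `LearnedAlibiPositionalBias` and the `alibi_learned` flag removed, ALiBi is wired up through `AlibiPositionalBias` alone. A hedged sketch of the configuration that is unchanged by this diff (hyperparameters are illustrative only):

```python
from x_transformers import TransformerWrapper, Decoder

# alibi_pos_bias and alibi_num_heads work as before; only the learned-slope
# variant is gone, so the fixed-slope AlibiPositionalBias is always used
model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        alibi_pos_bias = True,
        alibi_num_heads = 4    # must be <= heads, per the assert above
    )
)
```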
@@ -1135,7 +1137,9 @@
         assert not (self.cross_attend ^ exists(context)), 'context must be passed in if cross_attend is set to True'
 
         hiddens = []
+        layer_hiddens = []
         intermediates = []
+
         prev_attn = None
         prev_cross_attn = None
@@ -1165,6 +1169,9 @@
 
             inner_residual = x
 
+            if return_hiddens:
+                layer_hiddens.append(x)
+
             pre_norm, post_branch_norm, post_main_norm = norm
 
             if exists(pre_norm):
@@ -1196,6 +1203,9 @@
             if exists(post_main_norm):
                 x = post_main_norm(x)
 
+            if return_hiddens:
+                layer_hiddens.append(x)
+
         if self.resi_dual:
             x = x + self.final_norm(outer_residual)
         else:
@@ -1204,7 +1214,8 @@
         if return_hiddens:
             intermediates = LayerIntermediates(
                 hiddens = hiddens,
-                attn_intermediates = intermediates
+                attn_intermediates = intermediates,
+                layer_hiddens = layer_hiddens
             )
 
         return x, intermediates
@@ -1303,7 +1314,8 @@ class TransformerWrapper(nn.Module):
         use_abs_pos_emb = True,
         scaled_sinu_pos_emb = False,
         l2norm_embed = False,
-        emb_frac_gradient = 1
+        emb_frac_gradient = 1., # GLM-130B and Cogview successfully used this, set at 0.1
+        attn_z_loss_weight = 1e-4
     ):
         super().__init__()
         assert isinstance(attn_layers, AttentionLayers), 'attention layers must be one of Encoder or Decoder'
@@ -1368,10 +1380,12 @@ class TransformerWrapper(nn.Module):
         pos = None,
         prepend_embeds = None,
         sum_embeds = None,
+        return_attn_z_loss = False,
+        attn_z_loss_weight = 1e-4,
         **kwargs
     ):
         b, n, device, num_mem, emb_frac_gradient = *x.shape, x.device, self.num_memory_tokens, self.emb_frac_gradient
-        return_hiddens = return_mems | return_attn
+        return_hiddens = return_mems | return_attn | return_intermediates | return_attn_z_loss
 
         # absolute positional embedding
 
@@ -1434,6 +1448,11 @@ class TransformerWrapper(nn.Module):
         else:
             out = self.to_logits(x)
 
+        if return_attn_z_loss:
+            pre_softmax_attns = list(map(lambda t: t.pre_softmax_attn, intermediates.attn_intermediates))
+            intermediates.attn_z_loss = calc_z_loss(pre_softmax_attns, weight = attn_z_loss_weight)
+            return_intermediates = True
+
         if return_intermediates:
             return out, intermediates
 
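Putting the wrapper-level changes together, a hedged end-to-end sketch of requesting the new auxiliary z-loss at call time (model size and the loss composition are illustrative only):

```python
import torch
import torch.nn.functional as F
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 128,
    attn_layers = Decoder(dim = 128, depth = 2, heads = 4)
)

x = torch.randint(0, 256, (1, 128))

# return_attn_z_loss forces return_intermediates, so the call returns
# (logits, intermediates) and the auxiliary loss sits on the intermediates
logits, intermediates = model(x, return_attn_z_loss = True, attn_z_loss_weight = 1e-4)

loss = F.cross_entropy(logits.transpose(1, 2), x) + intermediates.attn_z_loss
loss.backward()
```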
{x_transformers-1.18.0.dist-info → x_transformers-1.18.2.dist-info}/RECORD
CHANGED
@@ -1,12 +1,12 @@
 x_transformers/__init__.py,sha256=FDb654rUx8FpXRd76B8q0diH8I7q-ZjTWEtEJ4UM21Y,701
-x_transformers/attend.py,sha256=
+x_transformers/attend.py,sha256=DRrcDq3ZwmGfuZ8hjLvKB-IJwPjnnVTGX76ONrlFSA4,11913
 x_transformers/autoregressive_wrapper.py,sha256=LuxuSSmZNXivv1XMijOTpNfPd6_uMemptsufL8l1Y7I,4514
 x_transformers/continuous_autoregressive_wrapper.py,sha256=pTiDqu6JRUlnQJQp_xHATYHy0lgSd6ERLqyiFO3pC-4,1575
 x_transformers/nonautoregressive_wrapper.py,sha256=AQLE4rA_Kh8VNoe9OzpwyeWson34sRkhks4dn4seNjI,10414
-x_transformers/x_transformers.py,sha256=
+x_transformers/x_transformers.py,sha256=BQtmUlLNGvaB6Xn5zK03dpO2OQT1jwpEODUTHZJz_ps,56443
 x_transformers/xl_autoregressive_wrapper.py,sha256=-CAYjTtqrks8ZTxjYm2stOelZpU4MbZIvLjUxWO0P9Y,3988
-x_transformers-1.18.
-x_transformers-1.18.
-x_transformers-1.18.
-x_transformers-1.18.
-x_transformers-1.18.
+x_transformers-1.18.2.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.18.2.dist-info/METADATA,sha256=3F1KCGDFqziINFhFhKLciVog2fnuOhigwAnyG49fb3A,661
+x_transformers-1.18.2.dist-info/WHEEL,sha256=yQN5g4mg4AybRjkgi-9yy4iQEFibGQmlz78Pik5Or-A,92
+x_transformers-1.18.2.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.18.2.dist-info/RECORD,,
{x_transformers-1.18.0.dist-info → x_transformers-1.18.2.dist-info}/LICENSE
File without changes
{x_transformers-1.18.0.dist-info → x_transformers-1.18.2.dist-info}/top_level.txt
File without changes