x-transformers 1.18.1__py3-none-any.whl → 1.18.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- x_transformers/attend.py +6 -4
- x_transformers/x_transformers.py +40 -6
- {x_transformers-1.18.1.dist-info → x_transformers-1.18.2.dist-info}/METADATA +1 -1
- {x_transformers-1.18.1.dist-info → x_transformers-1.18.2.dist-info}/RECORD +7 -7
- {x_transformers-1.18.1.dist-info → x_transformers-1.18.2.dist-info}/LICENSE +0 -0
- {x_transformers-1.18.1.dist-info → x_transformers-1.18.2.dist-info}/WHEEL +0 -0
- {x_transformers-1.18.1.dist-info → x_transformers-1.18.2.dist-info}/top_level.txt +0 -0
x_transformers/attend.py
CHANGED
@@ -1,4 +1,5 @@
 from functools import partial
+from typing import Optional
 
 import torch
 from torch import nn, einsum, Tensor
@@ -17,9 +18,9 @@ EfficientAttentionConfig = namedtuple('EfficientAttentionConfig', ['enable_flash
 
 @dataclass
 class Intermediates:
-    qk_similarities: Tensor = None
-    pre_softmax_attn: Tensor = None
-    post_softmax_attn: Tensor = None
+    qk_similarities: Optional[Tensor] = None
+    pre_softmax_attn: Optional[Tensor] = None
+    post_softmax_attn: Optional[Tensor] = None
 
     def to_tuple(self):
         return (self.qk_similarities, self.pre_softmax_attn, self.post_softmax_attn)
@@ -256,7 +257,6 @@ class Attend(nn.Module):
             dots = dots + attn_bias
 
         i, j, dtype = *dots.shape[-2:], dots.dtype
-        pre_softmax_attn = dots.clone()
 
         mask_value = -torch.finfo(dots.dtype).max
 
@@ -272,6 +272,8 @@ class Attend(nn.Module):
             causal_mask = self.create_causal_mask(i, j, device = device)
            dots = dots.masked_fill(causal_mask, mask_value)
 
+        pre_softmax_attn = dots.clone()
+
         attn = self.attn_fn(dots, dim = -1)
         attn = attn.type(dtype)
 
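The functional change in attend.py is the relocation of the pre_softmax_attn snapshot: it is now cloned after the attention bias and causal mask have been applied, so the stored logits already carry the mask_value fill for disallowed positions, which is what the new z-loss consumes downstream. A minimal sketch of the resulting ordering in the non-flash path; the function name and standalone form here are illustrative, only the order of operations comes from the hunks above:

```python
import torch

def masked_logits_snapshot(dots, causal_mask, attn_bias = None):
    # order of operations as of 1.18.2: bias, then mask, then snapshot, then softmax
    if attn_bias is not None:
        dots = dots + attn_bias

    mask_value = -torch.finfo(dots.dtype).max
    dots = dots.masked_fill(causal_mask, mask_value)

    # the snapshot is now taken *after* masking (previously it was taken before)
    pre_softmax_attn = dots.clone()

    attn = dots.softmax(dim = -1)
    return attn, pre_softmax_attn
```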
x_transformers/x_transformers.py
CHANGED
@@ -9,7 +9,7 @@ from functools import partial, wraps
 from inspect import isfunction
 from collections import namedtuple
 from dataclasses import dataclass
-from typing import List, Callable
+from typing import List, Callable, Optional
 
 from einops import rearrange, repeat, reduce
 from einops.layers.torch import Rearrange
@@ -23,9 +23,10 @@ DEFAULT_DIM_HEAD = 64
 
 @dataclass
 class LayerIntermediates:
-    hiddens: List[Tensor] = None
-    attn_intermediates: List[Intermediates] = None
-    layer_hiddens: List[Tensor] = None
+    hiddens: Optional[List[Tensor]] = None
+    attn_intermediates: Optional[List[Intermediates]] = None
+    layer_hiddens: Optional[List[Tensor]] = None
+    attn_z_loss: Optional[Tensor] = None
 
 # helpers
 
@@ -90,6 +91,31 @@ def or_reduce(masks):
         head = head | rest
     return head
 
+# auxiliary loss helpers
+
+def calc_z_loss(
+    pre_softmax_attns: List[Tensor],
+    mask = None,
+    weight = 1.
+):
+    # the same loss applied to the mixture of experts router logits in https://arxiv.org/abs/2202.08906
+    # in the paper, in a tiny footnote, they mention using it on attention logits with stabilizing effects
+    # also used in PaLM as one of the measures
+
+    lse = 0.
+
+    for attn in pre_softmax_attns:
+        lse = lse + attn.logsumexp(dim = -1)
+
+    loss = torch.square(lse)
+    loss = reduce(loss, 'b h n -> b n', 'sum')
+
+    if not exists(mask):
+        return loss.mean() * weight
+
+    loss = loss[mask].sum() / mask.sum().clamp(min = 1e-5)
+    return loss * weight
+
 # init helpers
 
 def init_zero_(layer):
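In other words, calc_z_loss sums the logsumexp of the pre-softmax attention logits across layers, squares it, sums over heads, averages over batch and sequence positions (optionally restricted by a mask), and scales by weight. A hedged usage sketch, assuming calc_z_loss is importable from the module as defined above and using random tensors purely to illustrate shapes:

```python
import torch
from x_transformers.x_transformers import calc_z_loss

b, h, n = 2, 8, 16

# one (batch, heads, query, key) logit tensor per attention layer
pre_softmax_attns = [torch.randn(b, h, n, n) for _ in range(6)]

z_loss = calc_z_loss(pre_softmax_attns, weight = 1e-4)

# equivalent computation spelled out, for clarity
lse = sum(a.logsumexp(dim = -1) for a in pre_softmax_attns)   # (b, h, n)
manual = lse.square().sum(dim = 1).mean() * 1e-4              # sum over heads, mean over batch and positions
assert torch.allclose(z_loss, manual)
```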
@@ -1288,7 +1314,8 @@ class TransformerWrapper(nn.Module):
         use_abs_pos_emb = True,
         scaled_sinu_pos_emb = False,
         l2norm_embed = False,
-        emb_frac_gradient = 1
+        emb_frac_gradient = 1., # GLM-130B and Cogview successfully used this, set at 0.1
+        attn_z_loss_weight = 1e-4
     ):
         super().__init__()
         assert isinstance(attn_layers, AttentionLayers), 'attention layers must be one of Encoder or Decoder'
@@ -1353,10 +1380,12 @@ class TransformerWrapper(nn.Module):
         pos = None,
         prepend_embeds = None,
         sum_embeds = None,
+        return_attn_z_loss = False,
+        attn_z_loss_weight = 1e-4,
         **kwargs
     ):
         b, n, device, num_mem, emb_frac_gradient = *x.shape, x.device, self.num_memory_tokens, self.emb_frac_gradient
-        return_hiddens = return_mems | return_attn | return_intermediates
+        return_hiddens = return_mems | return_attn | return_intermediates | return_attn_z_loss
 
         # absolute positional embedding
 
@@ -1419,6 +1448,11 @@ class TransformerWrapper(nn.Module):
         else:
             out = self.to_logits(x)
 
+        if return_attn_z_loss:
+            pre_softmax_attns = list(map(lambda t: t.pre_softmax_attn, intermediates.attn_intermediates))
+            intermediates.attn_z_loss = calc_z_loss(pre_softmax_attns, weight = attn_z_loss_weight)
+            return_intermediates = True
+
         if return_intermediates:
             return out, intermediates
 
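Taken together, 1.18.2 lets a caller request the auxiliary attention z-loss directly from the forward pass: passing return_attn_z_loss = True forces intermediates to be returned, with intermediates.attn_z_loss already scaled by attn_z_loss_weight. A hedged end-to-end sketch; the TransformerWrapper/Decoder constructor arguments follow the library's usual README-style usage and the specific values are illustrative only:

```python
import torch
import torch.nn.functional as F
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 1024,
    attn_layers = Decoder(dim = 512, depth = 6, heads = 8)
)

seq = torch.randint(0, 256, (1, 1024))

# return_attn_z_loss = True implies return_intermediates = True,
# so the forward pass yields (logits, intermediates)
logits, intermediates = model(seq[:, :-1], return_attn_z_loss = True)

ce_loss = F.cross_entropy(
    logits.transpose(1, 2),   # (batch, num_tokens, seq) as expected by cross_entropy
    seq[:, 1:]
)

# intermediates.attn_z_loss is already scaled by attn_z_loss_weight (default 1e-4)
loss = ce_loss + intermediates.attn_z_loss
loss.backward()
```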
{x_transformers-1.18.1.dist-info → x_transformers-1.18.2.dist-info}/RECORD
CHANGED
@@ -1,12 +1,12 @@
 x_transformers/__init__.py,sha256=FDb654rUx8FpXRd76B8q0diH8I7q-ZjTWEtEJ4UM21Y,701
-x_transformers/attend.py,sha256=
+x_transformers/attend.py,sha256=DRrcDq3ZwmGfuZ8hjLvKB-IJwPjnnVTGX76ONrlFSA4,11913
 x_transformers/autoregressive_wrapper.py,sha256=LuxuSSmZNXivv1XMijOTpNfPd6_uMemptsufL8l1Y7I,4514
 x_transformers/continuous_autoregressive_wrapper.py,sha256=pTiDqu6JRUlnQJQp_xHATYHy0lgSd6ERLqyiFO3pC-4,1575
 x_transformers/nonautoregressive_wrapper.py,sha256=AQLE4rA_Kh8VNoe9OzpwyeWson34sRkhks4dn4seNjI,10414
-x_transformers/x_transformers.py,sha256=
+x_transformers/x_transformers.py,sha256=BQtmUlLNGvaB6Xn5zK03dpO2OQT1jwpEODUTHZJz_ps,56443
 x_transformers/xl_autoregressive_wrapper.py,sha256=-CAYjTtqrks8ZTxjYm2stOelZpU4MbZIvLjUxWO0P9Y,3988
-x_transformers-1.18.
-x_transformers-1.18.
-x_transformers-1.18.
-x_transformers-1.18.
-x_transformers-1.18.
+x_transformers-1.18.2.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.18.2.dist-info/METADATA,sha256=3F1KCGDFqziINFhFhKLciVog2fnuOhigwAnyG49fb3A,661
+x_transformers-1.18.2.dist-info/WHEEL,sha256=yQN5g4mg4AybRjkgi-9yy4iQEFibGQmlz78Pik5Or-A,92
+x_transformers-1.18.2.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.18.2.dist-info/RECORD,,
{x_transformers-1.18.1.dist-info → x_transformers-1.18.2.dist-info}/LICENSE
File without changes
{x_transformers-1.18.1.dist-info → x_transformers-1.18.2.dist-info}/WHEEL
File without changes
{x_transformers-1.18.1.dist-info → x_transformers-1.18.2.dist-info}/top_level.txt
File without changes