x-transformers 1.18.1__py3-none-any.whl → 1.18.2__py3-none-any.whl

This diff shows the changes between publicly available package versions as they appear in their public registry. It is provided for informational purposes only.
x_transformers/attend.py CHANGED
@@ -1,4 +1,5 @@
 from functools import partial
+from typing import Optional
 
 import torch
 from torch import nn, einsum, Tensor
@@ -17,9 +18,9 @@ EfficientAttentionConfig = namedtuple('EfficientAttentionConfig', ['enable_flash
 
 @dataclass
 class Intermediates:
-    qk_similarities: Tensor = None
-    pre_softmax_attn: Tensor = None
-    post_softmax_attn: Tensor = None
+    qk_similarities: Optional[Tensor] = None
+    pre_softmax_attn: Optional[Tensor] = None
+    post_softmax_attn: Optional[Tensor] = None
 
     def to_tuple(self):
         return (self.qk_similarities, self.pre_softmax_attn, self.post_softmax_attn)
@@ -256,7 +257,6 @@ class Attend(nn.Module):
         dots = dots + attn_bias
 
         i, j, dtype = *dots.shape[-2:], dots.dtype
-        pre_softmax_attn = dots.clone()
 
         mask_value = -torch.finfo(dots.dtype).max
 
@@ -272,6 +272,8 @@ class Attend(nn.Module):
             causal_mask = self.create_causal_mask(i, j, device = device)
             dots = dots.masked_fill(causal_mask, mask_value)
 
+        pre_softmax_attn = dots.clone()
+
         attn = self.attn_fn(dots, dim = -1)
         attn = attn.type(dtype)
 
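Taken together, the attend.py hunks above move the pre_softmax_attn capture from before masking to after it, so the stored logits already reflect the masking applied to dots. A minimal standalone sketch of the effect, with toy shapes that are not part of the diff:

import torch

dots = torch.randn(1, 1, 4, 4)                      # (batch, heads, i, j) attention logits
mask_value = -torch.finfo(dots.dtype).max
causal_mask = torch.ones(4, 4).triu(1).bool()       # mask out future positions
dots = dots.masked_fill(causal_mask, mask_value)

pre_softmax_attn = dots.clone()                     # captured after masking, as in 1.18.2
print(pre_softmax_attn.logsumexp(dim = -1))         # masked positions contribute ~0

Because masked entries sit at the most negative finite value, they vanish under logsumexp, which is what the new z-loss below consumes.
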
x_transformers/x_transformers.py CHANGED
@@ -9,7 +9,7 @@ from functools import partial, wraps
 from inspect import isfunction
 from collections import namedtuple
 from dataclasses import dataclass
-from typing import List, Callable
+from typing import List, Callable, Optional
 
 from einops import rearrange, repeat, reduce
 from einops.layers.torch import Rearrange
@@ -23,9 +23,10 @@ DEFAULT_DIM_HEAD = 64
 
 @dataclass
 class LayerIntermediates:
-    hiddens: List[Tensor] = None
-    attn_intermediates: List[Intermediates] = None
-    layer_hiddens: List[Tensor] = None
+    hiddens: Optional[List[Tensor]] = None
+    attn_intermediates: Optional[List[Intermediates]] = None
+    layer_hiddens: Optional[List[Tensor]] = None
+    attn_z_loss: Optional[Tensor] = None
 
 # helpers
 
@@ -90,6 +91,31 @@ def or_reduce(masks):
         head = head | rest
     return head
 
+# auxiliary loss helpers
+
+def calc_z_loss(
+    pre_softmax_attns: List[Tensor],
+    mask = None,
+    weight = 1.
+):
+    # the same loss applied to the mixture of experts router logits in https://arxiv.org/abs/2202.08906
+    # in the paper, in a tiny footnote, they mention using it on attention logits with stabilizing effects
+    # also used in PaLM as one of the measures
+
+    lse = 0.
+
+    for attn in pre_softmax_attns:
+        lse = lse + attn.logsumexp(dim = -1)
+
+    loss = torch.square(lse)
+    loss = reduce(loss, 'b h n -> b n', 'sum')
+
+    if not exists(mask):
+        return loss.mean() * weight
+
+    loss = loss[mask].sum() / mask.sum().clamp(min = 1e-5)
+    return loss * weight
+
 # init helpers
 
 def init_zero_(layer):
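For reference, a small standalone sketch of the arithmetic calc_z_loss performs on its unmasked path, with illustrative shapes that are not part of the diff: accumulate the logsumexp of each layer's pre-softmax attention logits, square the result, sum over heads, then average over batch and positions and scale by the weight.

import torch
from einops import reduce

pre_softmax_attns = [torch.randn(2, 8, 16, 16)]                     # per-layer (b, h, n, n) logits
lse = sum(attn.logsumexp(dim = -1) for attn in pre_softmax_attns)   # (b, h, n)
loss = reduce(torch.square(lse), 'b h n -> b n', 'sum')             # sum over heads
z_loss = loss.mean() * 1e-4                                         # weighted scalar (illustrative weight of 1e-4)

When a boolean mask is passed, the function instead averages the per-position losses over the unmasked positions only (loss[mask].sum() / mask.sum()).
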
@@ -1288,7 +1314,8 @@ class TransformerWrapper(nn.Module):
         use_abs_pos_emb = True,
         scaled_sinu_pos_emb = False,
         l2norm_embed = False,
-        emb_frac_gradient = 1. # GLM-130B and Cogview successfully used this, set at 0.1
+        emb_frac_gradient = 1., # GLM-130B and Cogview successfully used this, set at 0.1
+        attn_z_loss_weight = 1e-4
     ):
         super().__init__()
         assert isinstance(attn_layers, AttentionLayers), 'attention layers must be one of Encoder or Decoder'
@@ -1353,10 +1380,12 @@
         pos = None,
         prepend_embeds = None,
         sum_embeds = None,
+        return_attn_z_loss = False,
+        attn_z_loss_weight = 1e-4,
         **kwargs
     ):
         b, n, device, num_mem, emb_frac_gradient = *x.shape, x.device, self.num_memory_tokens, self.emb_frac_gradient
-        return_hiddens = return_mems | return_attn | return_intermediates
+        return_hiddens = return_mems | return_attn | return_intermediates | return_attn_z_loss
 
         # absolute positional embedding
 
@@ -1419,6 +1448,11 @@
         else:
             out = self.to_logits(x)
 
+        if return_attn_z_loss:
+            pre_softmax_attns = list(map(lambda t: t.pre_softmax_attn, intermediates.attn_intermediates))
+            intermediates.attn_z_loss = calc_z_loss(pre_softmax_attns, weight = attn_z_loss_weight)
+            return_intermediates = True
+
         if return_intermediates:
             return out, intermediates
 
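A hedged usage sketch of the new flag (hyperparameters here are illustrative, not part of the diff): with the default non-flash attention path, passing return_attn_z_loss = True forces return_intermediates on, and intermediates.attn_z_loss arrives already scaled by attn_z_loss_weight (default 1e-4), so it can simply be added to the main objective.

import torch
import torch.nn.functional as F
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 128,
    attn_layers = Decoder(dim = 64, depth = 2, heads = 4)
)

seq = torch.randint(0, 256, (1, 129))
inp, target = seq[:, :-1], seq[:, 1:]

logits, intermediates = model(inp, return_attn_z_loss = True)

loss = F.cross_entropy(logits.transpose(1, 2), target)   # main language modeling loss
loss = loss + intermediates.attn_z_loss                  # auxiliary z-loss, already weighted
loss.backward()
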
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: x-transformers
-Version: 1.18.1
+Version: 1.18.2
 Summary: X-Transformers - Pytorch
 Home-page: https://github.com/lucidrains/x-transformers
 Author: Phil Wang
@@ -1,12 +1,12 @@
 x_transformers/__init__.py,sha256=FDb654rUx8FpXRd76B8q0diH8I7q-ZjTWEtEJ4UM21Y,701
-x_transformers/attend.py,sha256=eCgWuKrJJfOPcrQLBJ8_N0FAjJBOWcdGpuGMdY1UfTM,11854
+x_transformers/attend.py,sha256=DRrcDq3ZwmGfuZ8hjLvKB-IJwPjnnVTGX76ONrlFSA4,11913
 x_transformers/autoregressive_wrapper.py,sha256=LuxuSSmZNXivv1XMijOTpNfPd6_uMemptsufL8l1Y7I,4514
 x_transformers/continuous_autoregressive_wrapper.py,sha256=pTiDqu6JRUlnQJQp_xHATYHy0lgSd6ERLqyiFO3pC-4,1575
 x_transformers/nonautoregressive_wrapper.py,sha256=AQLE4rA_Kh8VNoe9OzpwyeWson34sRkhks4dn4seNjI,10414
-x_transformers/x_transformers.py,sha256=8b1fp4ezZ07wLSpR3q7izr-0n_7sKm85WuKxKHdU3bw,55263
+x_transformers/x_transformers.py,sha256=BQtmUlLNGvaB6Xn5zK03dpO2OQT1jwpEODUTHZJz_ps,56443
 x_transformers/xl_autoregressive_wrapper.py,sha256=-CAYjTtqrks8ZTxjYm2stOelZpU4MbZIvLjUxWO0P9Y,3988
-x_transformers-1.18.1.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-1.18.1.dist-info/METADATA,sha256=g1wD6O2VfVdjqP-OKptUt_7O_J4Pt6vcBs7rbrW-TSg,661
-x_transformers-1.18.1.dist-info/WHEEL,sha256=yQN5g4mg4AybRjkgi-9yy4iQEFibGQmlz78Pik5Or-A,92
-x_transformers-1.18.1.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
-x_transformers-1.18.1.dist-info/RECORD,,
+x_transformers-1.18.2.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.18.2.dist-info/METADATA,sha256=3F1KCGDFqziINFhFhKLciVog2fnuOhigwAnyG49fb3A,661
+x_transformers-1.18.2.dist-info/WHEEL,sha256=yQN5g4mg4AybRjkgi-9yy4iQEFibGQmlz78Pik5Or-A,92
+x_transformers-1.18.2.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.18.2.dist-info/RECORD,,