x-transformers 1.18.0__py3-none-any.whl → 1.18.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
x_transformers/attend.py CHANGED
@@ -1,4 +1,5 @@
 from functools import partial
+from typing import Optional
 
 import torch
 from torch import nn, einsum, Tensor
@@ -17,9 +18,9 @@ EfficientAttentionConfig = namedtuple('EfficientAttentionConfig', ['enable_flash
 
 @dataclass
 class Intermediates:
-    qk_similarities: Tensor = None
-    pre_softmax_attn: Tensor = None
-    post_softmax_attn: Tensor = None
+    qk_similarities: Optional[Tensor] = None
+    pre_softmax_attn: Optional[Tensor] = None
+    post_softmax_attn: Optional[Tensor] = None
 
     def to_tuple(self):
         return (self.qk_similarities, self.pre_softmax_attn, self.post_softmax_attn)
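The typing change above is a static-analysis fix: a dataclass field annotated `Tensor` cannot legally default to `None`, so strict type checkers reject it. A minimal sketch of the pattern (the `Example` class is illustrative, not part of the package):

```python
from dataclasses import dataclass
from typing import Optional

from torch import Tensor

@dataclass
class Example:
    # `value: Tensor = None` is flagged by strict type checkers, since None is not a Tensor;
    # Optional[Tensor] declares the None default explicitly without changing runtime behavior
    value: Optional[Tensor] = None
```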
@@ -256,7 +257,6 @@ class Attend(nn.Module):
             dots = dots + attn_bias
 
         i, j, dtype = *dots.shape[-2:], dots.dtype
-        pre_softmax_attn = dots.clone()
 
         mask_value = -torch.finfo(dots.dtype).max
 
@@ -272,6 +272,8 @@ class Attend(nn.Module):
272
272
  causal_mask = self.create_causal_mask(i, j, device = device)
273
273
  dots = dots.masked_fill(causal_mask, mask_value)
274
274
 
275
+ pre_softmax_attn = dots.clone()
276
+
275
277
  attn = self.attn_fn(dots, dim = -1)
276
278
  attn = attn.type(dtype)
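The relocation above means `pre_softmax_attn` is captured after the key-padding and causal masks have been applied, so any consumer of it, such as the attention z-loss introduced below, sees the same logits that feed the softmax. A small illustrative snippet (shapes and names are placeholders, not library code):

```python
import torch

b, h, n = 1, 2, 4
dots = torch.randn(b, h, n, n)
mask_value = -torch.finfo(dots.dtype).max
causal_mask = torch.ones(n, n, dtype = torch.bool).triu(1)

masked = dots.masked_fill(causal_mask, mask_value)

# the log-partition (logsumexp) over masked vs. unmasked logits generally differs,
# which is why cloning after masking matters for anything derived from these logits
print(dots.logsumexp(dim = -1))
print(masked.logsumexp(dim = -1))
```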
x_transformers/x_transformers.py CHANGED
@@ -9,7 +9,7 @@ from functools import partial, wraps
 from inspect import isfunction
 from collections import namedtuple
 from dataclasses import dataclass
-from typing import List, Callable
+from typing import List, Callable, Optional
 
 from einops import rearrange, repeat, reduce
 from einops.layers.torch import Rearrange
@@ -23,8 +23,10 @@ DEFAULT_DIM_HEAD = 64
 
 @dataclass
 class LayerIntermediates:
-    hiddens: List[Tensor] = None
-    attn_intermediates: List[Intermediates] = None
+    hiddens: Optional[List[Tensor]] = None
+    attn_intermediates: Optional[List[Intermediates]] = None
+    layer_hiddens: Optional[List[Tensor]] = None
+    attn_z_loss: Optional[Tensor] = None
 
 # helpers
 
@@ -89,6 +91,31 @@ def or_reduce(masks):
         head = head | rest
     return head
 
+# auxiliary loss helpers
+
+def calc_z_loss(
+    pre_softmax_attns: List[Tensor],
+    mask = None,
+    weight = 1.
+):
+    # the same loss applied to the mixture of experts router logits in https://arxiv.org/abs/2202.08906
+    # in the paper, in a tiny footnote, they mention using it on attention logits with stabilizing effects
+    # also used in PaLM as one of the measures
+
+    lse = 0.
+
+    for attn in pre_softmax_attns:
+        lse = lse + attn.logsumexp(dim = -1)
+
+    loss = torch.square(lse)
+    loss = reduce(loss, 'b h n -> b n', 'sum')
+
+    if not exists(mask):
+        return loss.mean() * weight
+
+    loss = loss[mask].sum() / mask.sum().clamp(min = 1e-5)
+    return loss * weight
+
 # init helpers
 
 def init_zero_(layer):
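For a single layer with logits of shape `(batch, heads, query, key)`, the loss above is the squared log-partition `z = logsumexp_j(logits)`, summed over heads and averaged over positions; when several layers are passed in, their `logsumexp` terms are summed before squaring. A shape-level sketch of the per-layer computation (all names here are illustrative):

```python
import torch
from einops import reduce

b, h, n = 2, 8, 16
pre_softmax_attn = torch.randn(b, h, n, n)            # attention logits for one layer

lse = pre_softmax_attn.logsumexp(dim = -1)             # (b, h, n): log-partition per query
loss = reduce(lse.square(), 'b h n -> b n', 'sum')     # sum the squared z over heads
print(loss.mean())                                     # unweighted z-loss for this layer
```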
@@ -373,29 +400,6 @@ class AlibiPositionalBias(nn.Module):
 
         return self.bias
 
-class LearnedAlibiPositionalBias(AlibiPositionalBias):
-    def __init__(self, heads, total_heads):
-        super().__init__(heads, total_heads)
-        log_slopes = torch.log(self.slopes)
-        self.learned_logslopes = nn.Parameter(log_slopes)
-
-    def forward(self, i, j):
-        h, device = self.heads, self.device
-
-        def get_slopes(param):
-            return pad_at_dim(param.exp(), (0, h - param.shape[0]), dim = -2)
-
-        if exists(self.bias) and self.bias.shape[-1] >= j and self.bias.shape[-2] >= i:
-            bias = self.bias[..., :i, :j]
-        else:
-            bias = self.get_bias(i, j, device)
-            self.register_buffer('bias', bias, persistent = False)
-
-        slopes = get_slopes(self.learned_logslopes)
-        bias = bias * slopes
-
-        return bias
-
 class RotaryEmbedding(nn.Module):
     def __init__(
         self,
@@ -907,7 +911,6 @@ class AttentionLayers(nn.Module):
         use_simple_rmsnorm = False,
         alibi_pos_bias = False,
         alibi_num_heads = None,
-        alibi_learned = False,
         rel_pos_bias = False,
         rel_pos_num_buckets = 32,
         rel_pos_max_distance = 128,
@@ -979,8 +982,7 @@ class AttentionLayers(nn.Module):
         elif alibi_pos_bias:
             alibi_num_heads = default(alibi_num_heads, heads)
             assert alibi_num_heads <= heads, 'number of ALiBi heads must be less than the total number of heads'
-            alibi_pos_klass = LearnedAlibiPositionalBias if alibi_learned else AlibiPositionalBias
-            self.rel_pos = alibi_pos_klass(heads = alibi_num_heads, total_heads = heads)
+            self.rel_pos = AlibiPositionalBias(heads = alibi_num_heads, total_heads = heads)
 
         # determine deepnorm and residual scale
 
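With `alibi_learned` and `LearnedAlibiPositionalBias` gone, ALiBi is now configured through `alibi_pos_bias` and `alibi_num_heads` alone. A minimal sketch of the remaining configuration path (hyperparameters are placeholders):

```python
from x_transformers import Decoder

# non-learned ALiBi bias applied to half of the attention heads
attn_layers = Decoder(
    dim = 512,
    depth = 6,
    heads = 8,
    alibi_pos_bias = True,
    alibi_num_heads = 4
)
```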
@@ -1135,7 +1137,9 @@ class AttentionLayers(nn.Module):
         assert not (self.cross_attend ^ exists(context)), 'context must be passed in if cross_attend is set to True'
 
         hiddens = []
+        layer_hiddens = []
         intermediates = []
+
         prev_attn = None
         prev_cross_attn = None
 
@@ -1165,6 +1169,9 @@ class AttentionLayers(nn.Module):
 
             inner_residual = x
 
+            if return_hiddens:
+                layer_hiddens.append(x)
+
             pre_norm, post_branch_norm, post_main_norm = norm
 
             if exists(pre_norm):
@@ -1196,6 +1203,9 @@ class AttentionLayers(nn.Module):
             if exists(post_main_norm):
                 x = post_main_norm(x)
 
+        if return_hiddens:
+            layer_hiddens.append(x)
+
         if self.resi_dual:
             x = x + self.final_norm(outer_residual)
         else:
@@ -1204,7 +1214,8 @@ class AttentionLayers(nn.Module):
         if return_hiddens:
             intermediates = LayerIntermediates(
                 hiddens = hiddens,
-                attn_intermediates = intermediates
+                attn_intermediates = intermediates,
+                layer_hiddens = layer_hiddens
             )
 
         return x, intermediates
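`layer_hiddens` records the residual stream at the entry of each attention/feedforward block, plus the final output, and now travels on the returned `LayerIntermediates`. A small sketch of reading it directly off an `AttentionLayers` stack (settings are placeholders, and the exact number of entries depends on depth):

```python
import torch
from x_transformers import Decoder

attn_layers = Decoder(dim = 512, depth = 6, heads = 8)

x = torch.randn(2, 64, 512)                        # already-embedded token sequence
out, intermediates = attn_layers(x, return_hiddens = True)

print(len(intermediates.layer_hiddens))            # per-block hidden states, plus the final one
```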
@@ -1303,7 +1314,8 @@ class TransformerWrapper(nn.Module):
         use_abs_pos_emb = True,
         scaled_sinu_pos_emb = False,
         l2norm_embed = False,
-        emb_frac_gradient = 1. # GLM-130B and Cogview successfully used this, set at 0.1
+        emb_frac_gradient = 1., # GLM-130B and Cogview successfully used this, set at 0.1
+        attn_z_loss_weight = 1e-4
     ):
         super().__init__()
         assert isinstance(attn_layers, AttentionLayers), 'attention layers must be one of Encoder or Decoder'
@@ -1368,10 +1380,12 @@ class TransformerWrapper(nn.Module):
         pos = None,
         prepend_embeds = None,
         sum_embeds = None,
+        return_attn_z_loss = False,
+        attn_z_loss_weight = 1e-4,
         **kwargs
     ):
         b, n, device, num_mem, emb_frac_gradient = *x.shape, x.device, self.num_memory_tokens, self.emb_frac_gradient
-        return_hiddens = return_mems | return_attn
+        return_hiddens = return_mems | return_attn | return_intermediates | return_attn_z_loss
 
         # absolute positional embedding
 
@@ -1434,6 +1448,11 @@ class TransformerWrapper(nn.Module):
         else:
             out = self.to_logits(x)
 
+        if return_attn_z_loss:
+            pre_softmax_attns = list(map(lambda t: t.pre_softmax_attn, intermediates.attn_intermediates))
+            intermediates.attn_z_loss = calc_z_loss(pre_softmax_attns, weight = attn_z_loss_weight)
+            return_intermediates = True
+
         if return_intermediates:
             return out, intermediates
 
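Putting it together: `return_attn_z_loss = True` gathers each layer's `pre_softmax_attn`, feeds it to `calc_z_loss`, stores the result on `intermediates.attn_z_loss`, and forces `return_intermediates`. A minimal training-step sketch against this signature (model size and data are placeholders):

```python
import torch
import torch.nn.functional as F
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 128,
    attn_layers = Decoder(dim = 512, depth = 4, heads = 8)
)

seq = torch.randint(0, 256, (2, 128))
inp, target = seq[:, :-1], seq[:, 1:]

logits, intermediates = model(inp, return_attn_z_loss = True, attn_z_loss_weight = 1e-4)

ce_loss = F.cross_entropy(logits.transpose(1, 2), target)
loss = ce_loss + intermediates.attn_z_loss   # z-loss is already scaled by attn_z_loss_weight
loss.backward()
```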
x_transformers-1.18.0.dist-info/METADATA → x_transformers-1.18.2.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: x-transformers
-Version: 1.18.0
+Version: 1.18.2
 Summary: X-Transformers - Pytorch
 Home-page: https://github.com/lucidrains/x-transformers
 Author: Phil Wang
x_transformers-1.18.0.dist-info/RECORD → x_transformers-1.18.2.dist-info/RECORD CHANGED
@@ -1,12 +1,12 @@
 x_transformers/__init__.py,sha256=FDb654rUx8FpXRd76B8q0diH8I7q-ZjTWEtEJ4UM21Y,701
-x_transformers/attend.py,sha256=eCgWuKrJJfOPcrQLBJ8_N0FAjJBOWcdGpuGMdY1UfTM,11854
+x_transformers/attend.py,sha256=DRrcDq3ZwmGfuZ8hjLvKB-IJwPjnnVTGX76ONrlFSA4,11913
 x_transformers/autoregressive_wrapper.py,sha256=LuxuSSmZNXivv1XMijOTpNfPd6_uMemptsufL8l1Y7I,4514
 x_transformers/continuous_autoregressive_wrapper.py,sha256=pTiDqu6JRUlnQJQp_xHATYHy0lgSd6ERLqyiFO3pC-4,1575
 x_transformers/nonautoregressive_wrapper.py,sha256=AQLE4rA_Kh8VNoe9OzpwyeWson34sRkhks4dn4seNjI,10414
-x_transformers/x_transformers.py,sha256=2SYRnF4VgOUR5NH1phvRw9S5_s_QCgpSCQ8Srp4dBvA,55909
+x_transformers/x_transformers.py,sha256=BQtmUlLNGvaB6Xn5zK03dpO2OQT1jwpEODUTHZJz_ps,56443
 x_transformers/xl_autoregressive_wrapper.py,sha256=-CAYjTtqrks8ZTxjYm2stOelZpU4MbZIvLjUxWO0P9Y,3988
-x_transformers-1.18.0.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-1.18.0.dist-info/METADATA,sha256=gcnxC2WFpo6zo6yKFRhSatYJWYLtkWxiDcvDgkAsVh0,661
-x_transformers-1.18.0.dist-info/WHEEL,sha256=5sUXSg9e4bi7lTLOHcm6QEYwO5TIF1TNbTSVFVjcJcc,92
-x_transformers-1.18.0.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
-x_transformers-1.18.0.dist-info/RECORD,,
+x_transformers-1.18.2.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.18.2.dist-info/METADATA,sha256=3F1KCGDFqziINFhFhKLciVog2fnuOhigwAnyG49fb3A,661
+x_transformers-1.18.2.dist-info/WHEEL,sha256=yQN5g4mg4AybRjkgi-9yy4iQEFibGQmlz78Pik5Or-A,92
+x_transformers-1.18.2.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.18.2.dist-info/RECORD,,
x_transformers-1.18.0.dist-info/WHEEL → x_transformers-1.18.2.dist-info/WHEEL CHANGED
@@ -1,5 +1,5 @@
 Wheel-Version: 1.0
-Generator: bdist_wheel (0.41.1)
+Generator: bdist_wheel (0.41.2)
 Root-Is-Purelib: true
 Tag: py3-none-any
 