x-transformers 1.28.0__py3-none-any.whl → 1.28.2__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
x_transformers/x_transformers.py

@@ -1,6 +1,5 @@
 import math
 from random import random
-from typing import Dict
 from packaging import version

 import torch
@@ -11,7 +10,7 @@ from torch.cuda.amp import autocast
 from functools import partial, wraps
 from collections import namedtuple
 from dataclasses import dataclass
-from typing import List, Callable, Optional, Union
+from typing import List, Dict, Tuple, Callable, Optional, Union

 from einops import rearrange, repeat, reduce, pack, unpack
 from einops.layers.torch import Rearrange
@@ -91,7 +90,10 @@ def l2norm(t, groups = 1):
     t = F.normalize(t, p = 2, dim = -1)
     return rearrange(t, '... g d -> ... (g d)')

-def pad_at_dim(t, pad, dim = -1, value = 0.):
+def pad_at_dim(t, pad: Tuple[int, int], dim = -1, value = 0.):
+    if pad == (0, 0):
+        return t
+
     dims_from_right = (- dim - 1) if dim < 0 else (t.ndim - dim - 1)
     zeros = ((0, 0) * dims_from_right)
     return F.pad(t, (*zeros, *pad), value = value)
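
The added guard makes pad_at_dim a no-op for a (0, 0) pad, which is what lets later call sites (the attn_bias hunk below) drop their own memory-key/value checks. A standalone sketch of the behavior, assuming only torch is installed and using illustrative tensor shapes:

    import torch
    import torch.nn.functional as F
    from typing import Tuple

    def pad_at_dim(t, pad: Tuple[int, int], dim = -1, value = 0.):
        # new fast path from this diff: zero padding returns the input unchanged
        if pad == (0, 0):
            return t

        dims_from_right = (- dim - 1) if dim < 0 else (t.ndim - dim - 1)
        zeros = ((0, 0) * dims_from_right)
        return F.pad(t, (*zeros, *pad), value = value)

    x = torch.randn(2, 8, 16)                                   # illustrative shape
    assert pad_at_dim(x, (0, 0)) is x                           # early return, same tensor object
    assert pad_at_dim(x, (3, 0), dim = -1).shape == (2, 8, 19)  # 3 slots prepended on the last dim
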
@@ -816,7 +818,7 @@ class Attention(nn.Module):
         return_intermediates = False,
         cache: Optional[Intermediates] = None,
     ):
-        b, n, h, kv_h, head_scale, device, has_context = x.shape[0], x.shape[1], self.heads, self.kv_heads, self.head_scale, x.device, exists(context)
+        b, n, h, kv_h, head_scale, num_mem_kv, device, has_context = x.shape[0], x.shape[1], self.heads, self.kv_heads, self.head_scale, self.num_mem_kv, x.device, exists(context)

         kv_input = default(context, x)

@@ -895,9 +897,7 @@ class Attention(nn.Module):

         # maybe append memory key / values

-        has_mem_kv = self.num_mem_kv > 0
-
-        if has_mem_kv:
+        if num_mem_kv > 0:
             mem_k, mem_v = map(lambda t: repeat(t, 'h n d -> b h n d', b = b), (self.mem_k, self.mem_v))

             if self.qk_norm:
@@ -933,6 +933,7 @@ class Attention(nn.Module):
             range_k = torch.arange(j, device = device)
             dist = rearrange(range_q, 'i -> 1 1 i 1') - rearrange(range_k, 'j -> 1 1 1 j')
             max_attend_past_mask = dist > self.max_attend_past
+            max_attend_past_mask = pad_at_dim(max_attend_past_mask, (num_mem_kv, 0), value = False, dim = -1) # handle memory key / values
             masks.append(max_attend_past_mask)

         if len(masks) > 0:
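
Memory key/values are prepended along the key axis, so a mask built from the i query and j key positions alone would be num_mem_kv columns short; the added line left-pads it with False so the memory slots are never masked out. A shape-only sketch with made-up sizes (not the library's defaults):

    import torch
    import torch.nn.functional as F

    i, j, num_mem_kv, max_attend_past = 4, 6, 2, 3   # illustrative values only

    range_q = torch.arange(j - i, j)                 # query positions
    range_k = torch.arange(j)                        # key positions, without memory slots
    dist = range_q.view(1, 1, -1, 1) - range_k.view(1, 1, 1, -1)
    mask = dist > max_attend_past                    # True = too far in the past, gets masked out

    # prepend num_mem_kv columns of False so the mask lines up with keys that include memory slots
    mask = F.pad(mask, (num_mem_kv, 0), value = False)
    assert mask.shape == (1, 1, i, j + num_mem_kv)
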
@@ -943,11 +944,7 @@ class Attention(nn.Module):
         attn_bias = None
         if exists(rel_pos):
             attn_bias = rel_pos(i, j)
-
-            # append with no bias for memory key / values
-
-            if has_mem_kv:
-                attn_bias = pad_at_dim(attn_bias, (self.num_mem_kv, 0), value = 0.)
+            attn_bias = pad_at_dim(attn_bias, (num_mem_kv, 0), value = 0.) # handle memory key / values

         # attention is all we need

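
Because pad_at_dim now short-circuits on (0, 0), the relative-position bias can be padded unconditionally: with no memory key/values the call simply returns the bias untouched, which is why the separate has_mem_kv branch was removed. Continuing the pad_at_dim sketch from above, with a made-up bias shape:

    import torch

    # uses pad_at_dim as defined in the sketch further up
    heads, i, j = 8, 4, 4                  # illustrative sizes
    num_mem_kv = 0                         # a configuration without memory key / values
    attn_bias = torch.randn(heads, i, j)   # stand-in for the rel_pos(i, j) bias

    padded = pad_at_dim(attn_bias, (num_mem_kv, 0), value = 0.)
    assert padded is attn_bias             # (0, 0) pad short-circuits, bias passes through unchanged
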

x_transformers-1.28.0.dist-info/METADATA → x_transformers-1.28.2.dist-info/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: x-transformers
-Version: 1.28.0
+Version: 1.28.2
 Summary: X-Transformers - Pytorch
 Home-page: https://github.com/lucidrains/x-transformers
 Author: Phil Wang

x_transformers-1.28.0.dist-info/RECORD → x_transformers-1.28.2.dist-info/RECORD

@@ -4,11 +4,11 @@ x_transformers/autoregressive_wrapper.py,sha256=gYKIN5Rm8dMYSTX5yHpg9sPYyZf9rsRT
 x_transformers/continuous.py,sha256=dpHK4NSMDQAJQ_N3Uj9rip0fYGXyu0QCCO_OfEdbRGs,6192
 x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
 x_transformers/nonautoregressive_wrapper.py,sha256=ys_p8obc7lTeeodCqvkRKxOXQ1C9T3j5Jwr-JbVgnXk,10432
-x_transformers/x_transformers.py,sha256=GvqVKQZRtIldnSWX4V6qE2sWOGruRvBhk4MVit7ZD_M,63897
+x_transformers/x_transformers.py,sha256=GhhRfzxOQoUAqEeT8VnSAtW7wIJ6aW_5DF4LnsqozdQ,64018
 x_transformers/xl_autoregressive_wrapper.py,sha256=DCx4n0_c1tFai4nOqaWVnqx2p9eutsZsDMiMP1ckxNU,4117
 x_transformers/xval.py,sha256=EN3hxxleTRGYeAz6i4x3U_PrOm9TjxMF3eDhMKGx59E,8575
-x_transformers-1.28.0.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-1.28.0.dist-info/METADATA,sha256=o1AbarRMIJY_R0gNaEm5SNUWm3YHEesLL2EEy_Uk6gA,661
-x_transformers-1.28.0.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
-x_transformers-1.28.0.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
-x_transformers-1.28.0.dist-info/RECORD,,
+x_transformers-1.28.2.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.28.2.dist-info/METADATA,sha256=oT95hrc_XiI7dMKF9ATWyUwir3cfSfeD1PFTZF2zpy4,661
+x_transformers-1.28.2.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
+x_transformers-1.28.2.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.28.2.dist-info/RECORD,,