x-transformers 1.28.1__py3-none-any.whl → 1.28.4__py3-none-any.whl
This diff shows the changes between publicly available package versions as they appear in their respective public registries and is provided for informational purposes only.
- x_transformers/__init__.py +2 -0
- x_transformers/x_transformers.py +9 -12
- {x_transformers-1.28.1.dist-info → x_transformers-1.28.4.dist-info}/METADATA +1 -1
- {x_transformers-1.28.1.dist-info → x_transformers-1.28.4.dist-info}/RECORD +7 -7
- {x_transformers-1.28.1.dist-info → x_transformers-1.28.4.dist-info}/LICENSE +0 -0
- {x_transformers-1.28.1.dist-info → x_transformers-1.28.4.dist-info}/WHEEL +0 -0
- {x_transformers-1.28.1.dist-info → x_transformers-1.28.4.dist-info}/top_level.txt +0 -0
x_transformers/__init__.py
CHANGED
x_transformers/x_transformers.py
CHANGED
@@ -1,6 +1,5 @@
 import math
 from random import random
-from typing import Dict
 from packaging import version
 
 import torch
@@ -11,7 +10,7 @@ from torch.cuda.amp import autocast
 from functools import partial, wraps
 from collections import namedtuple
 from dataclasses import dataclass
-from typing import List, Callable, Optional, Union
+from typing import List, Dict, Tuple, Callable, Optional, Union
 
 from einops import rearrange, repeat, reduce, pack, unpack
 from einops.layers.torch import Rearrange
@@ -91,7 +90,10 @@ def l2norm(t, groups = 1):
     t = F.normalize(t, p = 2, dim = -1)
     return rearrange(t, '... g d -> ... (g d)')
 
-def pad_at_dim(t, pad, dim = -1, value = 0.):
+def pad_at_dim(t, pad: Tuple[int, int], dim = -1, value = 0.):
+    if pad == (0, 0):
+        return t
+
     dims_from_right = (- dim - 1) if dim < 0 else (t.ndim - dim - 1)
     zeros = ((0, 0) * dims_from_right)
     return F.pad(t, (*zeros, *pad), value = value)
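For readers skimming the hunk above: `pad_at_dim` is a small wrapper around `F.pad` that pads a single dimension, and 1.28.4 adds a type hint plus an early return for the no-op `(0, 0)` pad. A self-contained sketch of the updated helper, with illustrative example tensors that are not taken from the package:

```python
import torch
import torch.nn.functional as F
from typing import Tuple

def pad_at_dim(t, pad: Tuple[int, int], dim = -1, value = 0.):
    # new in 1.28.4: a (0, 0) pad is a no-op, so hand back the tensor untouched
    if pad == (0, 0):
        return t

    # F.pad expects pad pairs ordered from the last dimension inward,
    # so count how many dimensions sit to the right of `dim`
    dims_from_right = (- dim - 1) if dim < 0 else (t.ndim - dim - 1)
    zeros = ((0, 0) * dims_from_right)
    return F.pad(t, (*zeros, *pad), value = value)

mask = torch.ones(2, 8, dtype = torch.bool)

print(pad_at_dim(mask, (3, 0), dim = -1, value = False).shape)  # torch.Size([2, 11])
print(pad_at_dim(mask, (0, 0)) is mask)                         # True: early return, no copy
```

The early return means a zero-width pad now returns the original tensor object instead of running `F.pad` on it.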
@@ -816,7 +818,7 @@ class Attention(nn.Module):
         return_intermediates = False,
         cache: Optional[Intermediates] = None,
     ):
-        b, n, h, kv_h, head_scale, device, has_context = x.shape[0], x.shape[1], self.heads, self.kv_heads, self.head_scale, x.device, exists(context)
+        b, n, h, kv_h, head_scale, num_mem_kv, device, has_context = x.shape[0], x.shape[1], self.heads, self.kv_heads, self.head_scale, self.num_mem_kv, x.device, exists(context)
 
         kv_input = default(context, x)
 
@@ -895,9 +897,7 @@ class Attention(nn.Module):
 
         # maybe append memory key / values
 
-
-
-        if has_mem_kv:
+        if num_mem_kv > 0:
             mem_k, mem_v = map(lambda t: repeat(t, 'h n d -> b h n d', b = b), (self.mem_k, self.mem_v))
 
             if self.qk_norm:
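The hunk above swaps the old `has_mem_kv` flag for a direct `num_mem_kv > 0` check on the value unpacked earlier in `forward`. The `repeat` call it guards broadcasts the learned per-head memory key / values across the batch; a rough sketch, with made-up sizes rather than package defaults:

```python
import torch
from einops import repeat

# illustrative sizes: heads, memory slots, head dim, batch (not package defaults)
h, num_mem_kv, d, b = 8, 2, 64, 4

mem_k = torch.randn(h, num_mem_kv, d)   # learned memory keys, shared across the batch
mem_v = torch.randn(h, num_mem_kv, d)   # learned memory values

# broadcast the per-head memory key / values across the batch,
# matching the repeat pattern shown in the hunk above
mem_k, mem_v = map(lambda t: repeat(t, 'h n d -> b h n d', b = b), (mem_k, mem_v))

print(mem_k.shape, mem_v.shape)         # torch.Size([4, 8, 2, 64]) for both
```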
@@ -933,6 +933,7 @@ class Attention(nn.Module):
             range_k = torch.arange(j, device = device)
             dist = rearrange(range_q, 'i -> 1 1 i 1') - rearrange(range_k, 'j -> 1 1 1 j')
             max_attend_past_mask = dist > self.max_attend_past
+            max_attend_past_mask = pad_at_dim(max_attend_past_mask, (num_mem_kv, 0), value = False, dim = -1) # handle memory key / values
             masks.append(max_attend_past_mask)
 
         if len(masks) > 0:
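This is the substantive fix: with memory key / values prepended along the key axis, the `max_attend_past` mask built from the `j` real keys is `num_mem_kv` columns short, so it is now left-padded with `False` to keep the memory slots attendable. A toy illustration using `F.pad` directly in place of `pad_at_dim` (all sizes are made up):

```python
import torch
import torch.nn.functional as F
from einops import rearrange

# made-up sizes: queries, keys, memory slots, attention window
i, j, num_mem_kv, max_attend_past = 4, 4, 2, 2

range_q = torch.arange(i)
range_k = torch.arange(j)
dist = rearrange(range_q, 'i -> 1 1 i 1') - rearrange(range_k, 'j -> 1 1 1 j')
max_attend_past_mask = dist > max_attend_past     # True marks key positions to drop

# the mask covers only the j real keys; left-pad with False so the prepended
# memory slots are never masked out (same effect as
# pad_at_dim(mask, (num_mem_kv, 0), value = False, dim = -1))
max_attend_past_mask = F.pad(max_attend_past_mask, (num_mem_kv, 0), value = False)

print(max_attend_past_mask.shape)   # torch.Size([1, 1, 4, 6]): num_mem_kv + j key positions
```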
@@ -943,11 +944,7 @@ class Attention(nn.Module):
         attn_bias = None
         if exists(rel_pos):
             attn_bias = rel_pos(i, j)
-
-        # append with no bias for memory key / values
-
-        if exists(attn_bias) and has_mem_kv:
-            attn_bias = pad_at_dim(attn_bias, (self.num_mem_kv, 0), value = 0.)
+            attn_bias = pad_at_dim(attn_bias, (num_mem_kv, 0), value = 0.) # handle memory key / values
 
         # attention is all we need
 
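The relative-position bias gets the matching treatment: it is left-padded with zeros so the prepended memory slots receive no positional bias, and the padding now happens directly inside the `exists(rel_pos)` branch instead of behind a separate `has_mem_kv` check. A short sketch with illustrative shapes:

```python
import torch
import torch.nn.functional as F

# illustrative relative-position bias: (1, 8 heads, 4 queries, 4 keys), not package values
attn_bias = torch.randn(1, 8, 4, 4)
num_mem_kv = 2

# prepend zero bias along the key dimension so the memory slots carry no positional
# bias (same effect as pad_at_dim(attn_bias, (num_mem_kv, 0), value = 0.))
attn_bias = F.pad(attn_bias, (num_mem_kv, 0), value = 0.)

print(attn_bias.shape)   # torch.Size([1, 8, 4, 6])
```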
{x_transformers-1.28.1.dist-info → x_transformers-1.28.4.dist-info}/RECORD
CHANGED
@@ -1,14 +1,14 @@
-x_transformers/__init__.py,sha256=
+x_transformers/__init__.py,sha256=8LQl-dNL6vj8VHRx5LMSOlRDTXQvYOuM21PDXz8WdiI,703
 x_transformers/attend.py,sha256=L7vctHJ0PnECohu4cUu8yvY8cUrVyJxHmMFR0RGL0z4,10163
 x_transformers/autoregressive_wrapper.py,sha256=gYKIN5Rm8dMYSTX5yHpg9sPYyZf9rsRTJCNrYRdJ-Ww,9618
 x_transformers/continuous.py,sha256=dpHK4NSMDQAJQ_N3Uj9rip0fYGXyu0QCCO_OfEdbRGs,6192
 x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
 x_transformers/nonautoregressive_wrapper.py,sha256=ys_p8obc7lTeeodCqvkRKxOXQ1C9T3j5Jwr-JbVgnXk,10432
-x_transformers/x_transformers.py,sha256=
+x_transformers/x_transformers.py,sha256=GhhRfzxOQoUAqEeT8VnSAtW7wIJ6aW_5DF4LnsqozdQ,64018
 x_transformers/xl_autoregressive_wrapper.py,sha256=DCx4n0_c1tFai4nOqaWVnqx2p9eutsZsDMiMP1ckxNU,4117
 x_transformers/xval.py,sha256=EN3hxxleTRGYeAz6i4x3U_PrOm9TjxMF3eDhMKGx59E,8575
-x_transformers-1.28.
-x_transformers-1.28.
-x_transformers-1.28.
-x_transformers-1.28.
-x_transformers-1.28.
+x_transformers-1.28.4.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.28.4.dist-info/METADATA,sha256=JKCQN6QEaSe9M63vpez9hdan0f67zSiu5okyl9GDNKU,661
+x_transformers-1.28.4.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
+x_transformers-1.28.4.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.28.4.dist-info/RECORD,,
{x_transformers-1.28.1.dist-info → x_transformers-1.28.4.dist-info}/LICENSE
File without changes
{x_transformers-1.28.1.dist-info → x_transformers-1.28.4.dist-info}/WHEEL
File without changes
{x_transformers-1.28.1.dist-info → x_transformers-1.28.4.dist-info}/top_level.txt
File without changes