x-transformers 1.29.1__py3-none-any.whl → 1.30.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- x_transformers/attend.py +7 -5
- x_transformers/autoregressive_wrapper.py +6 -4
- x_transformers/continuous.py +2 -2
- x_transformers/x_transformers.py +57 -24
- x_transformers/xval.py +2 -2
- {x_transformers-1.29.1.dist-info → x_transformers-1.30.0.dist-info}/METADATA +3 -3
- x_transformers-1.30.0.dist-info/RECORD +14 -0
- x_transformers-1.29.1.dist-info/RECORD +0 -14
- {x_transformers-1.29.1.dist-info → x_transformers-1.30.0.dist-info}/LICENSE +0 -0
- {x_transformers-1.29.1.dist-info → x_transformers-1.30.0.dist-info}/WHEEL +0 -0
- {x_transformers-1.29.1.dist-info → x_transformers-1.30.0.dist-info}/top_level.txt +0 -0
x_transformers/attend.py
CHANGED
@@ -1,5 +1,7 @@
+from __future__ import annotations
+
 from functools import partial
-from typing import
+from typing import Tuple

 import torch
 from torch import nn, einsum, Tensor
@@ -16,10 +18,10 @@ from einops import rearrange, repeat

 @dataclass
 class Intermediates:
-    qk_similarities:
-    pre_softmax_attn:
-    post_softmax_attn:
-    cached_kv:
+    qk_similarities: Tensor | None = None
+    pre_softmax_attn: Tensor | None = None
+    post_softmax_attn: Tensor | None = None
+    cached_kv: Tuple[Tensor, Tensor] | None = None

     def to_tuple(self):
         return (self.qk_similarities, self.pre_softmax_attn, self.post_softmax_attn)
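These annotations rely on the new `from __future__ import annotations` line: with postponed evaluation, `Tensor | None` is stored as a string and never evaluated, so the PEP 604 union syntax also works on Python versions before 3.10. A minimal, self-contained sketch of the same pattern (simplified from the dataclass above, not the package's full code):

from __future__ import annotations  # annotations are kept as strings, so `Tensor | None` is never evaluated

from dataclasses import dataclass
from typing import Tuple

import torch
from torch import Tensor

@dataclass
class Intermediates:
    qk_similarities: Tensor | None = None
    cached_kv: Tuple[Tensor, Tensor] | None = None

# instantiating works even on Python < 3.10, since the union in the annotation is never resolved
print(Intermediates(qk_similarities = torch.zeros(2, 2)))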
x_transformers/autoregressive_wrapper.py
CHANGED
@@ -1,5 +1,7 @@
+from __future__ import annotations
+
 from math import ceil, log
-from typing import
+from typing import Tuple, Callable

 import torch
 from torch import nn, Tensor
@@ -133,12 +135,12 @@ class AutoregressiveWrapper(Module):
         seq_len,
         eos_token = None,
         temperature = 1.,
-        prompt_lens:
+        prompt_lens: Tensor | None = None,
         filter_logits_fn: Callable = top_k,
         restrict_to_max_seq_len = True,
-        amateur_model:
+        amateur_model: Module | Tuple[Module] | None = None,
         filter_kwargs: dict = dict(),
-        contrastive_decode_kwargs:
+        contrastive_decode_kwargs: dict | Tuple[dict] = dict(
             beta = 0.5,
             alpha = 0.1
         ),
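The re-typed keyword arguments above all belong to `AutoregressiveWrapper.generate`. A hedged usage sketch, assuming the wrapper API from the project README; the model sizes and the `prompt_lens` value are illustrative only:

import torch
from x_transformers import TransformerWrapper, Decoder, AutoregressiveWrapper

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(dim = 512, depth = 6, heads = 8)
)

wrapper = AutoregressiveWrapper(model)

prompts = torch.randint(0, 20000, (2, 32))

generated = wrapper.generate(
    prompts,
    seq_len = 128,
    temperature = 1.,
    prompt_lens = torch.tensor([32, 20])   # now hinted as Tensor | None; per-sample prompt lengths (assumed usage)
)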
x_transformers/continuous.py
CHANGED
@@ -143,11 +143,11 @@ class ContinuousTransformerWrapper(nn.Module):

         if return_mems:
             hiddens = intermediates.hiddens
-            new_mems =
+            new_mems = tuple(t[..., -self.max_mem_len:, :].detach() for t in hiddens)
             return out, new_mems

         if return_attn:
-            attn_maps =
+            attn_maps = tuple(t.post_softmax_attn for t in intermediates.attn_intermediates)
             return out, attn_maps

         return out
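The rewritten lines now build `new_mems` and `attn_maps` as plain tuples. A hedged sketch of the call sites that receive them, assuming the `ContinuousTransformerWrapper` constructor from the project README (dimensions are illustrative):

import torch
from x_transformers import ContinuousTransformerWrapper, Decoder

model = ContinuousTransformerWrapper(
    dim_in = 32,
    dim_out = 32,
    max_seq_len = 1024,
    attn_layers = Decoder(dim = 512, depth = 6, heads = 8)
)

x = torch.randn(1, 256, 32)

out, mems = model(x, return_mems = True)    # mems: tuple of detached per-layer hiddens, trimmed to max_mem_len
out, attns = model(x, return_attn = True)   # attns: tuple of post-softmax attention maps, one per attention layer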
x_transformers/x_transformers.py
CHANGED
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import math
 from random import random
 from packaging import version
@@ -11,7 +13,7 @@ from torch.cuda.amp import autocast
 from functools import partial, wraps
 from collections import namedtuple
 from dataclasses import dataclass
-from typing import List, Dict, Tuple, Callable
+from typing import List, Dict, Tuple, Callable

 from einops import rearrange, repeat, reduce, pack, unpack
 from einops.layers.torch import Rearrange
@@ -25,13 +27,13 @@ DEFAULT_DIM_HEAD = 64

 @dataclass
 class LayerIntermediates:
-    hiddens:
-    last_hidden:
-    attn_intermediates:
-    layer_hiddens:
-    attn_z_loss:
-    mems:
-    memory_tokens:
+    hiddens: List[Tensor] | None = None   # all hiddens, before the final norm (in pre-norm architecture)
+    last_hidden: Tensor | None = None     # very last hidden after all attention layers, after the final norm
+    attn_intermediates: List[Intermediates] | None = None
+    layer_hiddens: List[Tensor] | None = None
+    attn_z_loss: Tensor | None = None
+    mems: Tensor | None = None
+    memory_tokens: Tensor | None = None

 # helpers

@@ -140,7 +142,7 @@ def init_zero_(layer):
 # keyword argument helpers

 def pick_and_pop(keys, d):
-    values =
+    values = tuple(d.pop(key) for key in keys)
     return dict(zip(keys, values))

 def group_dict_by_key(cond, d):
@@ -149,7 +151,7 @@ def group_dict_by_key(cond, d):
         match = bool(cond(key))
         ind = int(not match)
         return_val[ind][key] = d[key]
-    return (
+    return tuple(return_val)

 def string_begins_with(prefix, str):
     return str.startswith(prefix)
@@ -159,7 +161,8 @@ def group_by_key_prefix(prefix, d):

 def groupby_prefix_and_trim(prefix, d):
     kwargs_with_prefix, kwargs = group_dict_by_key(partial(string_begins_with, prefix), d)
-
+    prefix_len = len(prefix)
+    kwargs_without_prefix = {key[prefix_len:]: value for key, value in kwargs_with_prefix.items()}
     return kwargs_without_prefix, kwargs

 # structured dropout, more effective than traditional attention dropouts
@@ -455,7 +458,6 @@ class RotaryEmbedding(Module):

         return freqs, scale

-
 def rotate_half(x):
     x = rearrange(x, '... (j d) -> ... j d', j = 2)
     x1, x2 = x.unbind(dim = -2)
@@ -572,8 +574,8 @@ class GRUGating(Module):
 def shift(t, amount, mask = None):
     if amount == 0:
         return t
-
-
+
+    amount = min(amount, t.shape[1])

     if exists(mask):
         t = t.masked_fill(~mask[..., None], 0.)
@@ -597,6 +599,23 @@ class ShiftTokens(Module):
         x = torch.cat((*segments_to_shift, *rest), dim = -1)
         return self.fn(x, **kwargs)

+# post branch operator
+
+class LayerScale(Module):
+    def __init__(self, fn: Module, dim, init_value = 0.):
+        super().__init__()
+        self.fn = fn
+        self.gamma = nn.Parameter(torch.ones(dim) * init_value)
+
+    def forward(self, x, **kwargs):
+        out = self.fn(x, **kwargs)
+
+        if isinstance(out, Tensor):
+            return out * self.gamma
+
+        out, *rest = out
+        return out * self.gamma, *rest
+
 # feedforward

 class GLU(Module):
@@ -817,7 +836,7 @@ class Attention(Module):
         mem = None,
         mem_mask = None,
         return_intermediates = False,
-        cache:
+        cache: Intermediates | None = None,
     ):
         b, n, h, kv_h, head_scale, num_mem_kv, device, has_context = x.shape[0], x.shape[1], self.heads, self.kv_heads, self.head_scale, self.num_mem_kv, x.device, exists(context)

@@ -1024,11 +1043,11 @@ class AttentionLayers(Module):
         rotary_interpolation_factor = 1.,
         rotary_xpos_scale_base = 512,
         rotary_base_rescale_factor = 1.,
-
+        weight_tie_layers = False,
+        custom_layers: Tuple[str] | None = None,
+        layers_execute_order: Tuple[int] | None = None,
         sandwich_coef = None,
         par_ratio = None,
-        weight_tie_layers = False, # Albert - https://arxiv.org/abs/1909.11942
-        layers_execute_order = None, # generalizes weight tying, can do arbitrary layer execution orders
         residual_attn = False,
         cross_residual_attn = False,
         macaron = False,
@@ -1045,6 +1064,8 @@ class AttentionLayers(Module):
         layer_dropout = 0.,
         cross_attn_tokens_dropout = 0.,
         disable_abs_pos_emb = None,
+        use_layerscale = False,
+        layerscale_init_value = 0.,
         **kwargs
     ):
         super().__init__()
@@ -1108,6 +1129,8 @@ class AttentionLayers(Module):

         self.cross_attend = cross_attend

+        # determine norm
+
         assert (int(use_scalenorm) + int(use_rmsnorm) + int(use_simple_rmsnorm)) <= 1, 'you can only use either scalenorm, rmsnorm, or simple rmsnorm'

         if use_scalenorm:
@@ -1121,6 +1144,8 @@ class AttentionLayers(Module):

         norm_fn = partial(norm_class, dim)

+        # determine default block layer type order
+
         if cross_attend and not only_cross:
             default_block = ('a', 'c', 'f')
         elif cross_attend and only_cross:
@@ -1131,6 +1156,13 @@ class AttentionLayers(Module):
         if macaron:
             default_block = ('f',) + default_block

+        # determine post branch wrapper
+
+        post_branch_fn = None
+
+        if use_layerscale:
+            post_branch_fn = partial(LayerScale, dim = dim, init_value = layerscale_init_value)
+
         # zero init

         if zero_init_branch_output:
@@ -1178,11 +1210,9 @@ class AttentionLayers(Module):

         self.num_attn_layers = len(list(filter(equals('a'), layer_types)))

-        #
+        # set the depth

         depth = default(depth, len(self.layers_execute_order))
-        assert depth == len(self.layers_execute_order)
-
         self.depth = depth

         # stochastic depth
@@ -1221,6 +1251,9 @@ class AttentionLayers(Module):
                 shift_range_lower = -layer_shift_tokens if not causal else 0
                 layer = ShiftTokens(range(shift_range_lower, shift_range_upper), layer)

+            if exists(post_branch_fn):
+                layer = post_branch_fn(layer)
+
             residual_fn = GRUGating if gate_residual else Residual
             residual = residual_fn(dim, scale_residual = scale_residual, scale_residual_constant = scale_residual_constant)

@@ -1250,8 +1283,8 @@ class AttentionLayers(Module):
         self_attn_kv_mask = None,
         mems = None,
         mem_masks = None,
-        seq_start_pos:
-        cache:
+        seq_start_pos: Tensor | None = None,
+        cache: LayerIntermediates | None = None,
         cache_age = 1,
         return_hiddens = False,
         rotary_pos_emb = None
@@ -1643,7 +1676,7 @@ class TransformerWrapper(Module):
         return_attn_z_loss = False,
         attn_z_loss_weight = 1e-4,
         seq_start_pos = None,
-        cache:
+        cache: LayerIntermediates | None = None,
         **kwargs
     ):
         b, n, device, num_mems, has_memory_tokens, emb_frac_gradient = x.shape[0], x.shape[1], x.device, self.num_memory_tokens, self.num_memory_tokens > 0, self.emb_frac_gradient
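The headline change in this file is the new `LayerScale` post-branch wrapper, enabled through the `use_layerscale` / `layerscale_init_value` options on `AttentionLayers`. A minimal sketch of turning it on, assuming `Decoder` forwards its keyword arguments to `AttentionLayers` as elsewhere in the library (sizes are illustrative):

import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        use_layerscale = True,          # wrap every attention / feedforward branch in LayerScale
        layerscale_init_value = 0.      # gamma starts at zero, so branch outputs are learned in gradually
    )
)

x = torch.randint(0, 256, (1, 1024))
logits = model(x)   # (1, 1024, 256)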
x_transformers/xval.py
CHANGED
@@ -176,11 +176,11 @@ class XValTransformerWrapper(nn.Module):

         if return_mems:
             hiddens = intermediates.hiddens
-            new_mems =
+            new_mems = tuple(t[..., -self.max_mem_len:, :].detach() for t in hiddens)
             return out, new_mems

         if return_attn:
-            attn_maps =
+            attn_maps = tuple(t.post_softmax_attn for t in intermediates.attn_intermediates)
             return out, attn_maps

         return out
{x_transformers-1.29.1.dist-info → x_transformers-1.30.0.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: x-transformers
-Version: 1.
+Version: 1.30.0
 Summary: X-Transformers - Pytorch
 Home-page: https://github.com/lucidrains/x-transformers
 Author: Phil Wang
@@ -14,6 +14,6 @@ Classifier: License :: OSI Approved :: MIT License
 Classifier: Programming Language :: Python :: 3.6
 Description-Content-Type: text/markdown
 License-File: LICENSE
-Requires-Dist: torch >=
-Requires-Dist: einops >=0.
+Requires-Dist: torch >=2.0
+Requires-Dist: einops >=0.8.0

x_transformers-1.30.0.dist-info/RECORD
ADDED
@@ -0,0 +1,14 @@
+x_transformers/__init__.py,sha256=8LQl-dNL6vj8VHRx5LMSOlRDTXQvYOuM21PDXz8WdiI,703
+x_transformers/attend.py,sha256=Y9eE26I7BM8rGveabhiRhzw_xq9TY61Sp10QC1hX2O8,10192
+x_transformers/autoregressive_wrapper.py,sha256=uX8Mb0zLsQrZECt_9UGt35g7tC05Rk3nPqO6xp2FFCc,9619
+x_transformers/continuous.py,sha256=WO52n9lFAXv5-SGadi2cApGF8dkouN8QSTEOuC7erj8,6180
+x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
+x_transformers/nonautoregressive_wrapper.py,sha256=ys_p8obc7lTeeodCqvkRKxOXQ1C9T3j5Jwr-JbVgnXk,10432
+x_transformers/x_transformers.py,sha256=O-W-xnDhh4t3CKWDhmMJNobbIPQ2Vg24CJC8JBDou2M,65970
+x_transformers/xl_autoregressive_wrapper.py,sha256=DCx4n0_c1tFai4nOqaWVnqx2p9eutsZsDMiMP1ckxNU,4117
+x_transformers/xval.py,sha256=QE1ltYZTR_eGgIHPP2BrMWVWVLqMW-OpDZh87BSmQEg,8563
+x_transformers-1.30.0.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.30.0.dist-info/METADATA,sha256=Pd-ClX7tybO-h7mfYo2B8ncQM3jU3nQjf2yG1QBpORw,661
+x_transformers-1.30.0.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
+x_transformers-1.30.0.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.30.0.dist-info/RECORD,,
x_transformers-1.29.1.dist-info/RECORD
REMOVED
@@ -1,14 +0,0 @@
-x_transformers/__init__.py,sha256=8LQl-dNL6vj8VHRx5LMSOlRDTXQvYOuM21PDXz8WdiI,703
-x_transformers/attend.py,sha256=L7vctHJ0PnECohu4cUu8yvY8cUrVyJxHmMFR0RGL0z4,10163
-x_transformers/autoregressive_wrapper.py,sha256=gYKIN5Rm8dMYSTX5yHpg9sPYyZf9rsRTJCNrYRdJ-Ww,9618
-x_transformers/continuous.py,sha256=dpHK4NSMDQAJQ_N3Uj9rip0fYGXyu0QCCO_OfEdbRGs,6192
-x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
-x_transformers/nonautoregressive_wrapper.py,sha256=ys_p8obc7lTeeodCqvkRKxOXQ1C9T3j5Jwr-JbVgnXk,10432
-x_transformers/x_transformers.py,sha256=jj87ALpQpHGgvG1oHn4Z6UDmc1pqkoO6dY7YtY038w8,65269
-x_transformers/xl_autoregressive_wrapper.py,sha256=DCx4n0_c1tFai4nOqaWVnqx2p9eutsZsDMiMP1ckxNU,4117
-x_transformers/xval.py,sha256=EN3hxxleTRGYeAz6i4x3U_PrOm9TjxMF3eDhMKGx59E,8575
-x_transformers-1.29.1.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-1.29.1.dist-info/METADATA,sha256=4Nnxc5THUI-d21Szj2mPLTlZYF0A9xVjHN4laFiLCIE,661
-x_transformers-1.29.1.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
-x_transformers-1.29.1.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
-x_transformers-1.29.1.dist-info/RECORD,,
{x_transformers-1.29.1.dist-info → x_transformers-1.30.0.dist-info}/LICENSE
File without changes
{x_transformers-1.29.1.dist-info → x_transformers-1.30.0.dist-info}/WHEEL
File without changes
{x_transformers-1.29.1.dist-info → x_transformers-1.30.0.dist-info}/top_level.txt
File without changes