x-transformers 1.29.2__py3-none-any.whl → 1.30.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- x_transformers/attend.py +7 -5
- x_transformers/autoregressive_wrapper.py +6 -4
- x_transformers/continuous.py +2 -2
- x_transformers/x_transformers.py +65 -27
- x_transformers/xval.py +2 -2
- {x_transformers-1.29.2.dist-info → x_transformers-1.30.1.dist-info}/METADATA +3 -3
- x_transformers-1.30.1.dist-info/RECORD +14 -0
- x_transformers-1.29.2.dist-info/RECORD +0 -14
- {x_transformers-1.29.2.dist-info → x_transformers-1.30.1.dist-info}/LICENSE +0 -0
- {x_transformers-1.29.2.dist-info → x_transformers-1.30.1.dist-info}/WHEEL +0 -0
- {x_transformers-1.29.2.dist-info → x_transformers-1.30.1.dist-info}/top_level.txt +0 -0
x_transformers/attend.py
CHANGED
@@ -1,5 +1,7 @@
+from __future__ import annotations
+
 from functools import partial
-from typing import
+from typing import Tuple

 import torch
 from torch import nn, einsum, Tensor
@@ -16,10 +18,10 @@ from einops import rearrange, repeat

 @dataclass
 class Intermediates:
-    qk_similarities:
-    pre_softmax_attn:
-    post_softmax_attn:
-    cached_kv:
+    qk_similarities: Tensor | None = None
+    pre_softmax_attn: Tensor | None = None
+    post_softmax_attn: Tensor | None = None
+    cached_kv: Tuple[Tensor, Tensor] | None = None

     def to_tuple(self):
         return (self.qk_similarities, self.pre_softmax_attn, self.post_softmax_attn)
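The new `from __future__ import annotations` import makes all annotations lazy strings (PEP 563), so the dataclass fields can use PEP 604 `X | None` unions even on Python versions older than 3.10. A minimal, self-contained sketch of the same pattern (field names mirror `Intermediates`; this is illustrative, not the packaged code):

```python
from __future__ import annotations  # annotations are stored as strings, never evaluated here

from dataclasses import dataclass
from typing import Tuple

from torch import Tensor


@dataclass
class Intermediates:
    # `Tensor | None` is accepted on Python < 3.10 because the dataclass
    # machinery only inspects the annotation strings, it does not evaluate them
    qk_similarities: Tensor | None = None
    pre_softmax_attn: Tensor | None = None
    post_softmax_attn: Tensor | None = None
    cached_kv: Tuple[Tensor, Tensor] | None = None

    def to_tuple(self):
        return (self.qk_similarities, self.pre_softmax_attn, self.post_softmax_attn)
```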
x_transformers/autoregressive_wrapper.py
CHANGED
@@ -1,5 +1,7 @@
+from __future__ import annotations
+
 from math import ceil, log
-from typing import
+from typing import Tuple, Callable

 import torch
 from torch import nn, Tensor
@@ -133,12 +135,12 @@ class AutoregressiveWrapper(Module):
         seq_len,
         eos_token = None,
         temperature = 1.,
-        prompt_lens:
+        prompt_lens: Tensor | None = None,
         filter_logits_fn: Callable = top_k,
         restrict_to_max_seq_len = True,
-        amateur_model:
+        amateur_model: Module | Tuple[Module] | None = None,
         filter_kwargs: dict = dict(),
-        contrastive_decode_kwargs:
+        contrastive_decode_kwargs: dict | Tuple[dict] = dict(
             beta = 0.5,
             alpha = 0.1
         ),
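For context, `prompt_lens` lets `generate` handle a batch of prompts with differing true lengths, while `amateur_model` and `contrastive_decode_kwargs` drive contrastive decoding. A hypothetical call, assuming a model built as in the project README; only the argument names come from the signature above, the values are invented:

```python
import torch
from x_transformers import TransformerWrapper, Decoder, AutoregressiveWrapper

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(dim = 512, depth = 6, heads = 8)
)
wrapper = AutoregressiveWrapper(model)

prompts = torch.randint(0, 20000, (2, 128))   # batch of 2 prompts, padded to length 128
prompt_lens = torch.tensor([100, 128])        # per-sample true prompt lengths (illustrative)

generated = wrapper.generate(
    prompts,
    256,                         # number of tokens to sample
    prompt_lens = prompt_lens,   # variable-length prompt support
    temperature = 0.8,
)
```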
x_transformers/continuous.py
CHANGED
@@ -143,11 +143,11 @@ class ContinuousTransformerWrapper(nn.Module):

         if return_mems:
             hiddens = intermediates.hiddens
-            new_mems =
+            new_mems = tuple(t[..., -self.max_mem_len:, :].detach() for t in hiddens)
             return out, new_mems

         if return_attn:
-            attn_maps =
+            attn_maps = tuple(t.post_softmax_attn for t in intermediates.attn_intermediates)
             return out, attn_maps

         return out
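This hunk (and the identical one in `xval.py` further down) builds the returned memories as a tuple, keeping only the trailing `max_mem_len` positions of each layer's hidden state and detaching them from the autograd graph. A small standalone sketch of what that slice does, with made-up shapes:

```python
import torch

max_mem_len = 4
# pretend these are per-layer hidden states of shape (batch, seq, dim)
hiddens = [torch.randn(2, 10, 8) for _ in range(3)]

# keep only the last `max_mem_len` positions of every layer, detached from the graph
new_mems = tuple(t[..., -max_mem_len:, :].detach() for t in hiddens)

print([m.shape for m in new_mems])   # three tensors of shape (2, 4, 8)
```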
x_transformers/x_transformers.py
CHANGED
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import math
 from random import random
 from packaging import version
@@ -11,7 +13,7 @@ from torch.cuda.amp import autocast
 from functools import partial, wraps
 from collections import namedtuple
 from dataclasses import dataclass
-from typing import List, Dict, Tuple, Callable
+from typing import List, Dict, Tuple, Callable

 from einops import rearrange, repeat, reduce, pack, unpack
 from einops.layers.torch import Rearrange
@@ -25,13 +27,13 @@ DEFAULT_DIM_HEAD = 64

 @dataclass
 class LayerIntermediates:
-    hiddens:
-    last_hidden:
-    attn_intermediates:
-    layer_hiddens:
-    attn_z_loss:
-    mems:
-    memory_tokens:
+    hiddens: List[Tensor] | None = None   # all hiddens, before the final norm (in pre-norm architecture)
+    last_hidden: Tensor | None = None     # very last hidden after all attention layers, after the final norm
+    attn_intermediates: List[Intermediates] | None = None
+    layer_hiddens: List[Tensor] | None = None
+    attn_z_loss: Tensor | None = None
+    mems: Tensor | None = None
+    memory_tokens: Tensor | None = None

 # helpers

@@ -140,7 +142,7 @@ def init_zero_(layer):
 # keyword argument helpers

 def pick_and_pop(keys, d):
-    values =
+    values = tuple(d.pop(key) for key in keys)
     return dict(zip(keys, values))

 def group_dict_by_key(cond, d):
@@ -149,7 +151,7 @@ def group_dict_by_key(cond, d):
         match = bool(cond(key))
         ind = int(not match)
         return_val[ind][key] = d[key]
-    return (
+    return tuple(return_val)

 def string_begins_with(prefix, str):
     return str.startswith(prefix)
@@ -159,7 +161,8 @@ def group_by_key_prefix(prefix, d):

 def groupby_prefix_and_trim(prefix, d):
     kwargs_with_prefix, kwargs = group_dict_by_key(partial(string_begins_with, prefix), d)
-
+    prefix_len = len(prefix)
+    kwargs_without_prefix = {key[prefix_len:]: value for key, value in kwargs_with_prefix.items()}
     return kwargs_without_prefix, kwargs

 # structured dropout, more effective than traditional attention dropouts
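The rewritten keyword helpers now return tuples directly. To make the intent of `groupby_prefix_and_trim` concrete, here is a small illustrative run built from the code in the hunk above (the dictionary contents are invented for the example):

```python
from functools import partial

def string_begins_with(prefix, s):
    return s.startswith(prefix)

def group_dict_by_key(cond, d):
    # split `d` into (matching, non-matching) dicts according to `cond`
    return_val = [dict(), dict()]
    for key in d.keys():
        ind = int(not bool(cond(key)))
        return_val[ind][key] = d[key]
    return tuple(return_val)

def groupby_prefix_and_trim(prefix, d):
    # pull out prefixed kwargs and strip the prefix from their keys
    kwargs_with_prefix, kwargs = group_dict_by_key(partial(string_begins_with, prefix), d)
    prefix_len = len(prefix)
    kwargs_without_prefix = {key[prefix_len:]: value for key, value in kwargs_with_prefix.items()}
    return kwargs_without_prefix, kwargs

attn_kwargs, rest = groupby_prefix_and_trim('attn_', dict(attn_dropout = 0.1, ff_mult = 4))
print(attn_kwargs)   # {'dropout': 0.1}
print(rest)          # {'ff_mult': 4}
```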
@@ -441,25 +444,27 @@ class RotaryEmbedding(Module):

     @autocast(enabled = False)
     def forward(self, t):
-        max_pos = t.max()+1
+        max_pos = t.max() + 1

         freqs = torch.einsum('i , j -> i j', t.type_as(self.inv_freq), self.inv_freq) / self.interpolation_factor
-        freqs = torch.
+        freqs = torch.stack((freqs, freqs), dim = -1)
+        freqs = rearrange(freqs, '... d r -> ... (d r)')

         if not exists(self.scale):
             return freqs, 1.

         power = (t - (max_pos // 2)) / self.scale_base
         scale = self.scale ** rearrange(power, 'n -> n 1')
-        scale = torch.
+        scale = torch.stack((scale, scale), dim = -1)
+        scale = rearrange(scale, '... d r -> ... (d r)')

         return freqs, scale

-
 def rotate_half(x):
-    x = rearrange(x, '... (
-    x1, x2 = x.unbind(dim = -
-
+    x = rearrange(x, '... (d r) -> ... d r', r = 2)
+    x1, x2 = x.unbind(dim = -1)
+    x = torch.stack((-x2, x1), dim = -1)
+    return rearrange(x, '... d r -> ... (d r)')

 @autocast(enabled = False)
 def apply_rotary_pos_emb(t, freqs, scale = 1):
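The rotary helpers now build the frequencies and scales as interleaved `(d r)` pairs (r = 2) via `torch.stack` plus an einops flatten, and `rotate_half` rotates each adjacent pair. A tiny standalone check of what the new `rotate_half` does to a vector (values chosen purely for illustration):

```python
import torch
from einops import rearrange

def rotate_half(x):
    # split the last dim into (d, 2) pairs, then rotate each pair (x1, x2) -> (-x2, x1)
    x = rearrange(x, '... (d r) -> ... d r', r = 2)
    x1, x2 = x.unbind(dim = -1)
    x = torch.stack((-x2, x1), dim = -1)
    return rearrange(x, '... d r -> ... (d r)')

v = torch.tensor([1., 2., 3., 4.])
print(rotate_half(v))   # tensor([-2., 1., -4., 3.])
```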
@@ -572,8 +577,8 @@ class GRUGating(Module):

 def shift(t, amount, mask = None):
     if amount == 0:
         return t
-
-
+
+    amount = min(amount, t.shape[1])

     if exists(mask):
         t = t.masked_fill(~mask[..., None], 0.)
@@ -597,6 +602,23 @@ class ShiftTokens(Module):
         x = torch.cat((*segments_to_shift, *rest), dim = -1)
         return self.fn(x, **kwargs)

+# post branch operator
+
+class LayerScale(Module):
+    def __init__(self, fn: Module, dim, init_value = 0.):
+        super().__init__()
+        self.fn = fn
+        self.gamma = nn.Parameter(torch.ones(dim) * init_value)
+
+    def forward(self, x, **kwargs):
+        out = self.fn(x, **kwargs)
+
+        if isinstance(out, Tensor):
+            return out * self.gamma
+
+        out, *rest = out
+        return out * self.gamma, *rest
+
 # feedforward

 class GLU(Module):
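The new `LayerScale` wrapper implements the LayerScale technique from CaiT (Touvron et al.): each branch's output is multiplied by a learned per-channel `gamma`, initialised near zero so a deep residual stack starts close to identity. A minimal sketch of wrapping an arbitrary branch this way; the `branch` module and all values here are invented for the example:

```python
import torch
from torch import nn, Tensor

class LayerScale(nn.Module):
    def __init__(self, fn: nn.Module, dim, init_value = 0.):
        super().__init__()
        self.fn = fn
        self.gamma = nn.Parameter(torch.ones(dim) * init_value)  # per-channel scale, starts at init_value

    def forward(self, x, **kwargs):
        out = self.fn(x, **kwargs)
        if isinstance(out, Tensor):
            return out * self.gamma
        out, *rest = out                  # the branch may also return intermediates
        return out * self.gamma, *rest

# stand-in feedforward branch
branch = nn.Sequential(nn.Linear(512, 2048), nn.GELU(), nn.Linear(2048, 512))
scaled = LayerScale(branch, dim = 512, init_value = 1e-5)

x = torch.randn(2, 16, 512)
print(scaled(x).shape)   # torch.Size([2, 16, 512]); output is near zero at init since gamma is tiny
```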
@@ -817,7 +839,7 @@ class Attention(Module):
         mem = None,
         mem_mask = None,
         return_intermediates = False,
-        cache:
+        cache: Intermediates | None = None,
     ):
         b, n, h, kv_h, head_scale, num_mem_kv, device, has_context = x.shape[0], x.shape[1], self.heads, self.kv_heads, self.head_scale, self.num_mem_kv, x.device, exists(context)

@@ -1024,11 +1046,11 @@ class AttentionLayers(Module):
         rotary_interpolation_factor = 1.,
         rotary_xpos_scale_base = 512,
         rotary_base_rescale_factor = 1.,
-
+        weight_tie_layers = False,
+        custom_layers: Tuple[str] | None = None,
+        layers_execute_order: Tuple[int] | None = None,
         sandwich_coef = None,
         par_ratio = None,
-        weight_tie_layers = False, # Albert - https://arxiv.org/abs/1909.11942
-        layers_execute_order = None, # generalizes weight tying, can do arbitrary layer execution orders
         residual_attn = False,
         cross_residual_attn = False,
         macaron = False,
@@ -1045,6 +1067,8 @@ class AttentionLayers(Module):
         layer_dropout = 0.,
         cross_attn_tokens_dropout = 0.,
         disable_abs_pos_emb = None,
+        use_layerscale = False,
+        layerscale_init_value = 0.,
         **kwargs
     ):
         super().__init__()
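These two flags expose LayerScale at the `AttentionLayers` level. A hypothetical configuration, assuming the usual pattern of passing `AttentionLayers`-family kwargs through `Decoder`; the hyperparameter values are illustrative:

```python
import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 12,
        heads = 8,
        use_layerscale = True,          # wrap every attention / feedforward branch in LayerScale
        layerscale_init_value = 1e-5,   # small initial gamma, in the spirit of the CaiT paper
    )
)

logits = model(torch.randint(0, 20000, (1, 256)))
print(logits.shape)   # torch.Size([1, 256, 20000])
```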
@@ -1108,6 +1132,8 @@ class AttentionLayers(Module):

         self.cross_attend = cross_attend

+        # determine norm
+
         assert (int(use_scalenorm) + int(use_rmsnorm) + int(use_simple_rmsnorm)) <= 1, 'you can only use either scalenorm, rmsnorm, or simple rmsnorm'

         if use_scalenorm:
@@ -1121,6 +1147,8 @@ class AttentionLayers(Module):

         norm_fn = partial(norm_class, dim)

+        # determine default block layer type order
+
         if cross_attend and not only_cross:
             default_block = ('a', 'c', 'f')
         elif cross_attend and only_cross:
@@ -1131,6 +1159,13 @@ class AttentionLayers(Module):
         if macaron:
             default_block = ('f',) + default_block

+        # determine post branch wrapper
+
+        post_branch_fn = None
+
+        if use_layerscale:
+            post_branch_fn = partial(LayerScale, dim = dim, init_value = layerscale_init_value)
+
         # zero init

         if zero_init_branch_output:
@@ -1219,6 +1254,9 @@ class AttentionLayers(Module):
                 shift_range_lower = -layer_shift_tokens if not causal else 0
                 layer = ShiftTokens(range(shift_range_lower, shift_range_upper), layer)

+            if exists(post_branch_fn):
+                layer = post_branch_fn(layer)
+
             residual_fn = GRUGating if gate_residual else Residual
             residual = residual_fn(dim, scale_residual = scale_residual, scale_residual_constant = scale_residual_constant)

@@ -1248,8 +1286,8 @@ class AttentionLayers(Module):
         self_attn_kv_mask = None,
         mems = None,
         mem_masks = None,
-        seq_start_pos:
-        cache:
+        seq_start_pos: Tensor | None = None,
+        cache: LayerIntermediates | None = None,
         cache_age = 1,
         return_hiddens = False,
         rotary_pos_emb = None
@@ -1641,7 +1679,7 @@ class TransformerWrapper(Module):
         return_attn_z_loss = False,
         attn_z_loss_weight = 1e-4,
         seq_start_pos = None,
-        cache:
+        cache: LayerIntermediates | None = None,
         **kwargs
     ):
         b, n, device, num_mems, has_memory_tokens, emb_frac_gradient = x.shape[0], x.shape[1], x.device, self.num_memory_tokens, self.num_memory_tokens > 0, self.emb_frac_gradient
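The `cache: LayerIntermediates | None` annotations formalise the key/value cache that flows from the wrapper down through the attention layers to each `Attention` block, whose `Intermediates.cached_kv` holds the per-layer keys and values. As a rough, library-independent sketch of why such a cache helps during decoding (this assumes nothing about x-transformers' internals beyond that idea):

```python
import torch

def attend_with_cache(q_new, k_new, v_new, cached_kv = None):
    # q_new / k_new / v_new: (batch, heads, 1, dim_head) for the single newest token
    if cached_kv is not None:
        k_old, v_old = cached_kv
        k = torch.cat((k_old, k_new), dim = -2)   # reuse all previously computed keys
        v = torch.cat((v_old, v_new), dim = -2)   # ... and values
    else:
        k, v = k_new, v_new

    attn = (q_new @ k.transpose(-1, -2) * k.shape[-1] ** -0.5).softmax(dim = -1)
    out = attn @ v
    return out, (k, v)   # return the updated cache for the next decoding step

q = k = v = torch.randn(1, 8, 1, 64)
out, cache = attend_with_cache(q, k, v)          # first token: no cache yet
out, cache = attend_with_cache(q, k, v, cache)   # later tokens attend over cached k/v too
print(cache[0].shape)                            # torch.Size([1, 8, 2, 64])
```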
x_transformers/xval.py
CHANGED
@@ -176,11 +176,11 @@ class XValTransformerWrapper(nn.Module):

         if return_mems:
             hiddens = intermediates.hiddens
-            new_mems =
+            new_mems = tuple(t[..., -self.max_mem_len:, :].detach() for t in hiddens)
             return out, new_mems

         if return_attn:
-            attn_maps =
+            attn_maps = tuple(t.post_softmax_attn for t in intermediates.attn_intermediates)
             return out, attn_maps

         return out
{x_transformers-1.29.2.dist-info → x_transformers-1.30.1.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: x-transformers
-Version: 1.29.2
+Version: 1.30.1
 Summary: X-Transformers - Pytorch
 Home-page: https://github.com/lucidrains/x-transformers
 Author: Phil Wang
@@ -14,6 +14,6 @@ Classifier: License :: OSI Approved :: MIT License
 Classifier: Programming Language :: Python :: 3.6
 Description-Content-Type: text/markdown
 License-File: LICENSE
-Requires-Dist: torch >=
-Requires-Dist: einops >=0.
+Requires-Dist: torch >=2.0
+Requires-Dist: einops >=0.8.0

x_transformers-1.30.1.dist-info/RECORD
ADDED
@@ -0,0 +1,14 @@
+x_transformers/__init__.py,sha256=8LQl-dNL6vj8VHRx5LMSOlRDTXQvYOuM21PDXz8WdiI,703
+x_transformers/attend.py,sha256=Y9eE26I7BM8rGveabhiRhzw_xq9TY61Sp10QC1hX2O8,10192
+x_transformers/autoregressive_wrapper.py,sha256=uX8Mb0zLsQrZECt_9UGt35g7tC05Rk3nPqO6xp2FFCc,9619
+x_transformers/continuous.py,sha256=WO52n9lFAXv5-SGadi2cApGF8dkouN8QSTEOuC7erj8,6180
+x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
+x_transformers/nonautoregressive_wrapper.py,sha256=ys_p8obc7lTeeodCqvkRKxOXQ1C9T3j5Jwr-JbVgnXk,10432
+x_transformers/x_transformers.py,sha256=EEfqwI-NANzrQf10Tc_bRSdjWOIEJdhxOfzeKY4osyI,66137
+x_transformers/xl_autoregressive_wrapper.py,sha256=DCx4n0_c1tFai4nOqaWVnqx2p9eutsZsDMiMP1ckxNU,4117
+x_transformers/xval.py,sha256=QE1ltYZTR_eGgIHPP2BrMWVWVLqMW-OpDZh87BSmQEg,8563
+x_transformers-1.30.1.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.30.1.dist-info/METADATA,sha256=gkmRLAvk0l9_vkrTVBIWLnFq_cEtCYrI8oI3B07d9B8,661
+x_transformers-1.30.1.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
+x_transformers-1.30.1.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.30.1.dist-info/RECORD,,
x_transformers-1.29.2.dist-info/RECORD
DELETED
@@ -1,14 +0,0 @@
-x_transformers/__init__.py,sha256=8LQl-dNL6vj8VHRx5LMSOlRDTXQvYOuM21PDXz8WdiI,703
-x_transformers/attend.py,sha256=L7vctHJ0PnECohu4cUu8yvY8cUrVyJxHmMFR0RGL0z4,10163
-x_transformers/autoregressive_wrapper.py,sha256=gYKIN5Rm8dMYSTX5yHpg9sPYyZf9rsRTJCNrYRdJ-Ww,9618
-x_transformers/continuous.py,sha256=dpHK4NSMDQAJQ_N3Uj9rip0fYGXyu0QCCO_OfEdbRGs,6192
-x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
-x_transformers/nonautoregressive_wrapper.py,sha256=ys_p8obc7lTeeodCqvkRKxOXQ1C9T3j5Jwr-JbVgnXk,10432
-x_transformers/x_transformers.py,sha256=vPt5x0Pg03xGf8t2rZGW0zPd8xP0uvGLQvROFlmmOao,65200
-x_transformers/xl_autoregressive_wrapper.py,sha256=DCx4n0_c1tFai4nOqaWVnqx2p9eutsZsDMiMP1ckxNU,4117
-x_transformers/xval.py,sha256=EN3hxxleTRGYeAz6i4x3U_PrOm9TjxMF3eDhMKGx59E,8575
-x_transformers-1.29.2.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-1.29.2.dist-info/METADATA,sha256=0_ON52HHs50Dcwp4PMfGhLWhDGKC9Rd4V3QAvmqxGyo,661
-x_transformers-1.29.2.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
-x_transformers-1.29.2.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
-x_transformers-1.29.2.dist-info/RECORD,,
{x_transformers-1.29.2.dist-info → x_transformers-1.30.1.dist-info}/LICENSE: File without changes
{x_transformers-1.29.2.dist-info → x_transformers-1.30.1.dist-info}/WHEEL: File without changes
{x_transformers-1.29.2.dist-info → x_transformers-1.30.1.dist-info}/top_level.txt: File without changes