x-transformers 1.21.2__py3-none-any.whl → 1.21.4__py3-none-any.whl
This diff compares the contents of two publicly released versions of this package, as they appear in their respective public registries. It is provided for informational purposes only.
- x_transformers/x_transformers.py +51 -9
- x_transformers/xl_autoregressive_wrapper.py +6 -1
- {x_transformers-1.21.2.dist-info → x_transformers-1.21.4.dist-info}/METADATA +1 -1
- {x_transformers-1.21.2.dist-info → x_transformers-1.21.4.dist-info}/RECORD +7 -7
- {x_transformers-1.21.2.dist-info → x_transformers-1.21.4.dist-info}/LICENSE +0 -0
- {x_transformers-1.21.2.dist-info → x_transformers-1.21.4.dist-info}/WHEEL +0 -0
- {x_transformers-1.21.2.dist-info → x_transformers-1.21.4.dist-info}/top_level.txt +0 -0
x_transformers/x_transformers.py
CHANGED
@@ -11,7 +11,7 @@ from collections import namedtuple
 from dataclasses import dataclass
 from typing import List, Callable, Optional

-from einops import rearrange, repeat, reduce
+from einops import rearrange, repeat, reduce, pack, unpack
 from einops.layers.torch import Rearrange

 from x_transformers.attend import Attend, Intermediates, CascadingHeads
@@ -27,6 +27,7 @@ class LayerIntermediates:
     attn_intermediates: Optional[List[Intermediates]] = None
     layer_hiddens: Optional[List[Tensor]] = None
     attn_z_loss: Optional[Tensor] = None
+    mems: Optional[Tensor] = None

 # helpers

@@ -778,8 +779,8 @@ class Attention(nn.Module):
         r_input = x

         if exists(mem):
-            k_input =
-            v_input =
+            k_input, mem_packed_shape = pack([mem, k_input], 'b * d')
+            v_input, _ = pack([mem, v_input], 'b * d')

         q = self.to_q(q_input)
         k = self.to_k(k_input)
@@ -792,11 +793,21 @@

         if exists(cache) and not has_context:
             ck, cv = cache.cached_kv
+
+            if exists(mem):
+                mk, k = unpack(k, mem_packed_shape, 'b h * d')
+                mv, v = unpack(v, mem_packed_shape, 'b h * d')
+
             k = torch.cat((ck, k), dim = -2)
             v = torch.cat((cv, v), dim = -2)

+            if exists(mem):
+                k = torch.cat((mk, k), dim = -2)
+                v = torch.cat((mv, v), dim = -2)
+
         if return_intermediates:
-
+            mem_len = mem.shape[-2] if exists(mem) else 0
+            cached_kv = (k[..., mem_len:, :], v[..., mem_len:, :])

         if self.qk_norm:
             qk_l2norm = partial(l2norm, groups = self.qk_norm_groups)
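The hunks above replace the manual memory-token concatenation with einops' pack/unpack, so the memory prefix can be split back off the keys and values before the kv cache is stored. A minimal standalone sketch of that pattern, with illustrative tensor names and sizes (not the library's own):

    import torch
    from einops import pack, unpack

    mem = torch.randn(2, 4, 64)    # (batch, num memory tokens, dim)
    x   = torch.randn(2, 16, 64)   # (batch, sequence length, dim)

    # pack concatenates along the axis marked '*' and records each part's length
    packed, packed_shape = pack([mem, x], 'b * d')    # packed has shape (2, 20, 64)

    # unpack splits the packed tensor back into its parts; only the '*' axis lengths
    # are stored, which is why the same packed_shape still works after a heads axis
    # has been added (the 'b h * d' pattern in the hunk above)
    mem_out, x_out = unpack(packed, packed_shape, 'b * d')
    assert mem_out.shape == mem.shape and x_out.shape == x.shape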
@@ -941,6 +952,8 @@ class AttentionLayers(nn.Module):
         custom_layers = None,
         sandwich_coef = None,
         par_ratio = None,
+        weight_tie_layers = False,   # Albert - https://arxiv.org/abs/1909.11942
+        layers_execute_order = None, # generalizes weight tying, can do arbitrary layer execution orders
         residual_attn = False,
         cross_residual_attn = False,
         macaron = False,
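The two new constructor arguments enable ALBERT-style weight tying and arbitrary layer execution orders. A hedged usage sketch of `weight_tie_layers`, assuming the usual x-transformers entry points (`TransformerWrapper`, `Decoder`) and illustrative hyperparameters:

    import torch
    from x_transformers import TransformerWrapper, Decoder

    # one attention/feedforward block's weights, executed `depth` times
    model = TransformerWrapper(
        num_tokens = 20000,
        max_seq_len = 1024,
        attn_layers = Decoder(
            dim = 512,
            depth = 6,
            heads = 8,
            weight_tie_layers = True   # flag added in this diff
        )
    )

    tokens = torch.randint(0, 20000, (1, 256))
    logits = model(tokens)   # (1, 256, 20000)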
@@ -1046,6 +1059,15 @@
             attn_kwargs = {**attn_kwargs, 'zero_init_output': True}
             ff_kwargs = {**ff_kwargs, 'zero_init_output': True}

+        # setup weight tying, which is a special case of `layer_execute_order`
+
+        assert not (weight_tie_layers and any([*map(exists, (custom_layers, par_ratio, sandwich_coef))]))
+
+        if weight_tie_layers:
+            assert not exists(layers_execute_order)
+            layers_execute_order = tuple(range(len(default_block))) * depth
+            depth = 1
+
         # calculate layer block order

         if exists(custom_layers):
@@ -1068,6 +1090,10 @@
             layer_types = default_block * depth

         self.layer_types = layer_types
+        self.layers_execute_order = default(layers_execute_order, tuple(range(len(layer_types))))
+
+        assert all([i < len(self.layer_types) for i in self.layers_execute_order])
+
         self.num_attn_layers = len(list(filter(equals('a'), layer_types)))

         # stochastic depth
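Taken together, the two hunks above reduce weight tying to an execution order: a single block is built, and its indices are repeated `depth` times. A standalone trace of that reduction, with illustrative values:

    default_block = ('a', 'f')   # a single attention + feedforward block
    depth = 6

    # weight_tie_layers = True: reuse the same block's indices `depth` times
    layers_execute_order = tuple(range(len(default_block))) * depth   # (0, 1, 0, 1, ..., 0, 1)
    depth = 1

    layer_types = default_block * depth   # only one 'a'/'f' pair of modules is actually built
    assert all(i < len(layer_types) for i in layers_execute_order)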
@@ -1176,7 +1202,19 @@

         outer_residual = x * self.resi_dual_scale

-
+        # get layers to be executed
+
+        layer_variables = (
+            self.layer_types,
+            self.layers,
+            self.layer_dropouts
+        )
+
+        layer_variables = tuple(tuple(layer_variable[i] for i in self.layers_execute_order) for layer_variable in layer_variables)
+
+        # go through the attention and feedforward layers
+
+        for ind, (layer_type, (norm, block, residual_fn), layer_dropout) in enumerate(zip(*layer_variables)):
             is_last = ind == (len(self.layers) - 1)

             if self.training and layer_dropout > 0. and random() < layer_dropout:
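At forward time the execution order is applied by plain tuple indexing, and repeated indices simply re-run the same module. A self-contained illustration with stand-in values:

    layer_types    = ('a', 'f')                 # markers for the built modules
    layers         = ('attn_block', 'ff_block') # stand-ins for the real (norm, block, residual_fn) tuples
    layer_dropouts = (0.0, 0.0)

    layers_execute_order = (0, 1, 0, 1, 0, 1)   # run the single block three times

    layer_variables = (layer_types, layers, layer_dropouts)
    layer_variables = tuple(
        tuple(layer_variable[i] for i in layers_execute_order)
        for layer_variable in layer_variables
    )

    for layer_type, layer, dropout in zip(*layer_variables):
        print(layer_type, layer, dropout)   # prints the a/f pair three times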
@@ -1475,14 +1513,18 @@ class TransformerWrapper(nn.Module):
             intermediates.attn_z_loss = calc_z_loss(pre_softmax_attns, weight = attn_z_loss_weight)
             return_intermediates = True

-        if return_intermediates:
-            return out, intermediates
-
         if return_mems:
             hiddens = intermediates.hiddens
             new_mems = list(map(lambda pair: torch.cat(pair, dim = -2), zip(mems, hiddens))) if exists(mems) else hiddens
             new_mems = list(map(lambda t: t[..., -self.max_mem_len:, :].detach(), new_mems))
-
+
+            if not return_intermediates:
+                return out, new_mems
+
+            intermediates.mems = new_mems
+
+        if return_intermediates:
+            return out, intermediates

         if return_attn:
             attn_maps = list(map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates))
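Previously the early `return_intermediates` return meant that `return_mems` had no effect once intermediates were requested; after this change the refreshed memories are attached to the intermediates instead. A hedged call-site sketch (model construction follows the usual x-transformers pattern; sizes are illustrative):

    import torch
    from x_transformers import TransformerWrapper, Decoder

    model = TransformerWrapper(
        num_tokens = 256,
        max_seq_len = 128,
        max_mem_len = 128,
        attn_layers = Decoder(dim = 128, depth = 2, heads = 4)
    )
    x = torch.randint(0, 256, (1, 128))

    # return_mems alone: the second return value is the memory list, as before
    logits, mems = model(x, return_mems = True)

    # with return_intermediates as well, a single intermediates object comes back,
    # and the refreshed memories are now available on it
    logits, intermediates = model(x, mems = mems, return_mems = True, return_intermediates = True)
    new_mems = intermediates.mems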
x_transformers/xl_autoregressive_wrapper.py
CHANGED
@@ -67,6 +67,7 @@ class XLAutoregressiveWrapper(nn.Module):
         curr_pos = len(all_leading_tokens) * max_seq_len
         curr_mems = mems

+        cache = None
         out = start_tokens

         for _ in range(seq_len):
@@ -75,13 +76,17 @@

             x = out[:, curr_pos:]

-            logits,
+            logits, cache = self.net(
                 x,
                 mems = curr_mems,
+                cache = cache,
                 return_mems = True,
+                return_intermediates = True,
                 **kwargs
             )

+            mems = cache.mems
+
             logits = logits[:, -1]
             filtered_logits = filter_logits_fn(logits, thres = filter_thres)
             probs = F.softmax(filtered_logits / temperature, dim=-1)
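With these changes the XL generation loop threads a kv cache alongside the recurrence memories, reading the refreshed memories off the returned intermediates (`cache.mems`) each segment. A hedged end-to-end sketch, assuming `XLAutoregressiveWrapper` is importable from the package root as in the README, with illustrative sizes:

    import torch
    from x_transformers import TransformerWrapper, Decoder, XLAutoregressiveWrapper

    model = TransformerWrapper(
        num_tokens = 256,
        max_seq_len = 64,     # segment length for the XL recurrence
        max_mem_len = 64,
        attn_layers = Decoder(dim = 128, depth = 2, heads = 4)
    )
    xl = XLAutoregressiveWrapper(model)

    prompt = torch.randint(0, 256, (1, 4))
    generated = xl.generate(prompt, seq_len = 128)   # (1, 128) sampled token ids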
{x_transformers-1.21.2.dist-info → x_transformers-1.21.4.dist-info}/RECORD
CHANGED
@@ -3,10 +3,10 @@ x_transformers/attend.py,sha256=_LepMOwph_o3jio6tur8LEUBPM-2YIn7NpuBmqhU47E,1238
 x_transformers/autoregressive_wrapper.py,sha256=djA4nfE6_92SAzX1JI0KaC7krdLz1mvnZlaVOaerHDg,5372
 x_transformers/continuous_autoregressive_wrapper.py,sha256=pTiDqu6JRUlnQJQp_xHATYHy0lgSd6ERLqyiFO3pC-4,1575
 x_transformers/nonautoregressive_wrapper.py,sha256=AQLE4rA_Kh8VNoe9OzpwyeWson34sRkhks4dn4seNjI,10414
-x_transformers/x_transformers.py,sha256=
-x_transformers/xl_autoregressive_wrapper.py,sha256
-x_transformers-1.21.
-x_transformers-1.21.
-x_transformers-1.21.
-x_transformers-1.21.
-x_transformers-1.21.
+x_transformers/x_transformers.py,sha256=vss6ISABCV74wgEapIt8nPK50j9-QU54hugRiaBh-sw,58088
+x_transformers/xl_autoregressive_wrapper.py,sha256=DCx4n0_c1tFai4nOqaWVnqx2p9eutsZsDMiMP1ckxNU,4117
+x_transformers-1.21.4.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.21.4.dist-info/METADATA,sha256=HzzezgisQhEH2H6D2tI-JDqNghpuUi6pDlHg0AI976U,661
+x_transformers-1.21.4.dist-info/WHEEL,sha256=yQN5g4mg4AybRjkgi-9yy4iQEFibGQmlz78Pik5Or-A,92
+x_transformers-1.21.4.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.21.4.dist-info/RECORD,,
{x_transformers-1.21.2.dist-info → x_transformers-1.21.4.dist-info}/LICENSE
File without changes
{x_transformers-1.21.2.dist-info → x_transformers-1.21.4.dist-info}/WHEEL
File without changes
{x_transformers-1.21.2.dist-info → x_transformers-1.21.4.dist-info}/top_level.txt
File without changes