x-transformers 1.36.0__py3-none-any.whl → 1.37.0__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as published to one of the supported registries. It is provided for informational purposes only and reflects the changes between the versions as they appear in their respective public registries.
- x_transformers/autoregressive_wrapper.py +3 -1
- x_transformers/x_transformers.py +72 -33
- {x_transformers-1.36.0.dist-info → x_transformers-1.37.0.dist-info}/METADATA +1 -1
- {x_transformers-1.36.0.dist-info → x_transformers-1.37.0.dist-info}/RECORD +7 -7
- {x_transformers-1.36.0.dist-info → x_transformers-1.37.0.dist-info}/LICENSE +0 -0
- {x_transformers-1.36.0.dist-info → x_transformers-1.37.0.dist-info}/WHEEL +0 -0
- {x_transformers-1.36.0.dist-info → x_transformers-1.37.0.dist-info}/top_level.txt +0 -0
x_transformers/autoregressive_wrapper.py CHANGED

```diff
@@ -317,7 +317,9 @@ class AutoregressiveWrapper(Module):
             **kwargs
         )
 
-        loss = F.cross_entropy(
+        loss_fn = F.cross_entropy if not self.net.is_log_prob else F.nll_loss
+
+        loss = loss_fn(
             rearrange(logits, 'b n c -> b c n'),
             target,
             ignore_index = ignore_index
```
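The wrapper now picks its loss based on whether the underlying network already returns log-probabilities (`is_log_prob` is set when the new mixture-of-softmax head is enabled). `F.cross_entropy` applies `log_softmax` internally, so feeding it log-probabilities would normalize twice; `F.nll_loss` only gathers the negative log-likelihood. A minimal sketch of the equivalence, with made-up shapes:

```python
import torch
import torch.nn.functional as F

logits = torch.randn(2, 5, 3)           # (batch, num_tokens, seq) raw scores
target = torch.randint(0, 5, (2, 3))    # (batch, seq) token ids

# cross_entropy = log_softmax + nll_loss, so for raw logits ...
loss_a = F.cross_entropy(logits, target)

# ... while a network that already outputs log-probabilities only needs nll_loss
log_probs = logits.log_softmax(dim = 1)
loss_b = F.nll_loss(log_probs, target)

assert torch.allclose(loss_a, loss_b, atol = 1e-6)
```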
x_transformers/x_transformers.py CHANGED

```diff
@@ -1,4 +1,5 @@
 from __future__ import annotations
+from typing import Callable
 
 import math
 from random import random, randrange
```
```diff
@@ -14,7 +15,6 @@ from functools import partial, wraps
 from collections import namedtuple
 from contextlib import nullcontext
 from dataclasses import dataclass
-from typing import List, Dict, Tuple, Callable
 
 from einops import rearrange, repeat, reduce, pack, unpack
 from einops.layers.torch import Rearrange
```
```diff
@@ -28,14 +28,16 @@ DEFAULT_DIM_HEAD = 64
 
 @dataclass
 class LayerIntermediates:
-    hiddens:
+    hiddens: list[Tensor] | None = None   # all hiddens, before the final norm (in pre-norm architecture)
     last_hidden: Tensor | None = None     # very last hidden after all attention layers, after the final norm
-    attn_intermediates:
-    layer_hiddens:
+    attn_intermediates: list[Intermediates] | None = None
+    layer_hiddens: list[Tensor] | None = None
     attn_z_loss: Tensor | None = None
     mems: Tensor | None = None
     memory_tokens: Tensor | None = None
 
+LinearNoBias = partial(nn.Linear, bias = False)
+
 # helpers
 
 def exists(val):
```
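`LinearNoBias` is a small convenience used throughout the rest of this diff: a `functools.partial` of `nn.Linear` with `bias = False`, so each projection shrinks to a single call. A quick sketch of the equivalence, with hypothetical dimensions:

```python
from functools import partial
import torch
from torch import nn

LinearNoBias = partial(nn.Linear, bias = False)

proj = LinearNoBias(512, 64)        # same module as nn.Linear(512, 64, bias = False)
assert proj.bias is None

out = proj(torch.randn(2, 512))     # shape (2, 64), a pure matrix multiply
```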
```diff
@@ -92,6 +94,9 @@ def Sequential(*modules):
 
 # tensor helpers
 
+def log(t, eps = 1e-20):
+    return t.clamp(min = eps).log()
+
 def max_neg_value(tensor):
     return -torch.finfo(tensor.dtype).max
 
```
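The new `log` helper clamps its input before taking the logarithm, so exact zeros (which can appear after the probability-space mixing added further down) map to a large negative value instead of `-inf`. A tiny illustration:

```python
import torch

def log(t, eps = 1e-20):
    return t.clamp(min = eps).log()

probs = torch.tensor([0.0, 0.5])
print(log(probs))      # tensor([-46.0517, -0.6931]) rather than [-inf, -0.6931]
```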
```diff
@@ -114,7 +119,7 @@ def masked_mean(t, mask = None, dim = 1):
     den = mask.sum(dim = dim).clamp(min = 1.)
     return num / den
 
-def pad_at_dim(t, pad:
+def pad_at_dim(t, pad: tuple[int, int], dim = -1, value = 0.):
     if pad == (0, 0):
         return t
 
```
```diff
@@ -131,7 +136,7 @@ def or_reduce(masks):
 # auxiliary loss helpers
 
 def calc_z_loss(
-    pre_softmax_attns:
+    pre_softmax_attns: list[Tensor],
     mask = None,
     weight = 1.
 ):
```
```diff
@@ -611,7 +616,7 @@ class AdaptiveLayerNorm(Module):
         dim_condition = default(dim_condition, dim)
 
         self.ln = nn.LayerNorm(dim, elementwise_affine = False)
-        self.to_gamma =
+        self.to_gamma = LinearNoBias(dim_condition, dim)
         nn.init.zeros_(self.to_gamma.weight)
 
     def forward(self, x, *, condition):
```
```diff
@@ -666,7 +671,7 @@ class AdaptiveRMSNorm(Module):
         self.scale = dim ** 0.5
         dim_condition = default(dim_condition, dim)
 
-        self.to_gamma =
+        self.to_gamma = LinearNoBias(dim_condition, dim)
         nn.init.zeros_(self.to_gamma.weight)
 
     def forward(self, x, *, condition):
```
```diff
@@ -749,7 +754,7 @@ class ShiftTokens(Module):
         feats_per_shift = x.shape[-1] // segments
         splitted = x.split(feats_per_shift, dim = -1)
         segments_to_shift, rest = splitted[:segments], splitted[segments:]
-        segments_to_shift =
+        segments_to_shift = [shift(*args, mask = mask) for args in zip(segments_to_shift, shifts)]
         x = torch.cat((*segments_to_shift, *rest), dim = -1)
         return self.fn(x, **kwargs)
 
```
```diff
@@ -817,7 +822,7 @@ class ConcatCombine(Module):
     def __init__(self, dim, prev_layer_ind):
         super().__init__()
         self.prev_layer_ind = prev_layer_ind
-        self.combine =
+        self.combine = LinearNoBias(dim * 2, dim)
 
     def forward(self, x, prev_layers: list[Tensor]):
         skip = prev_layers[self.prev_layer_ind]
```
```diff
@@ -957,17 +962,17 @@ class Attention(Module):
         v_dim = value_dim_head * kv_heads
         out_dim = value_dim_head * heads
 
-        self.to_q =
-        self.to_k =
+        self.to_q = LinearNoBias(dim, q_dim)
+        self.to_k = LinearNoBias(dim_kv, k_dim)
 
         # shared key / values, for further memory savings during inference
 
         assert not (shared_kv and value_dim_head != dim_head), 'key and value head dimensions must be equal for shared key / values'
-        self.to_v =
+        self.to_v = LinearNoBias(dim_kv, v_dim) if not shared_kv else None
 
         # relations projection from tp-attention
 
-        self.to_r =
+        self.to_r = LinearNoBias(dim, v_dim) if tensor_product else None
 
         # add GLU gating for aggregated values, from alphafold2
 
```
```diff
@@ -1063,7 +1068,7 @@ class Attention(Module):
         # output dimension by default same as input, but can be overridden
 
         dim_out = default(dim_out, dim)
-        self.to_out = nn.Sequential(
+        self.to_out = nn.Sequential(LinearNoBias(out_dim, dim_out * 2), nn.GLU()) if on_attn else LinearNoBias(out_dim, dim_out)
 
         # whether to rotate positions into values, for absolute positions in addition to relative
 
```
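With `on_attn` the output projection is now written on one line: a bias-free projection to twice the output width followed by `nn.GLU()`, which splits the last dimension into a value half and a gate half and computes `value * sigmoid(gate)`. A shape sketch with assumed dimensions:

```python
import torch
from torch import nn

out_dim, dim_out = 512, 512
to_out = nn.Sequential(nn.Linear(out_dim, dim_out * 2, bias = False), nn.GLU())

x = torch.randn(2, 10, out_dim)
print(to_out(x).shape)      # torch.Size([2, 10, 512]); GLU halves the last dimension
```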
```diff
@@ -1109,7 +1114,7 @@ class Attention(Module):
 
         q = rearrange(q, 'b n (h d) -> b h n d', h = h)
 
-        k, v, r =
+        k, v, r = tuple(maybe(rearrange)(t, 'b n (h d) -> b h n d', h = kv_h) for t in (k, v, r))
 
         if exists(cache):
             ck, cv = cache.cached_kv
```
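Since `to_v` and `to_r` can now be `None` (shared key/values, or tensor-product relations disabled), the per-head rearrange is wrapped in the library's `maybe` helper so that `None` passes through untouched. A rough sketch of the idea, not the library's exact implementation:

```python
import torch
from einops import rearrange

def maybe(fn):
    # call fn only when the first argument is not None
    def inner(t, *args, **kwargs):
        return fn(t, *args, **kwargs) if t is not None else None
    return inner

k = torch.randn(2, 16, 8 * 64)
r = None                                               # e.g. tensor product disabled

k, r = (maybe(rearrange)(t, 'b n (h d) -> b h n d', h = 8) for t in (k, r))
print(k.shape, r)                                      # torch.Size([2, 8, 16, 64]) None
```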
```diff
@@ -1164,12 +1169,12 @@ class Attention(Module):
 
         # i, j determined for relative positional bias, excluding memory key / values
 
-        i, j =
+        i, j = tuple(t.shape[-2] for t in (q, k))
 
         # maybe append memory key / values
 
         if num_mem_kv > 0:
-            mem_k, mem_v =
+            mem_k, mem_v = tuple(repeat(t, 'h n d -> b h n d', b = b) for t in (self.mem_k, self.mem_v))
 
             if self.qk_norm:
                 mem_k = l2norm(mem_k)
```
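The learned memory key/values carry no batch dimension and are broadcast across the batch with `einops.repeat` before being prepended. A sketch with assumed shapes:

```python
import torch
from einops import repeat

heads, num_mem_kv, dim_head, b = 8, 2, 64, 4
mem_k = torch.randn(heads, num_mem_kv, dim_head)       # learned parameter, no batch dim

mem_k_batched = repeat(mem_k, 'h n d -> b h n d', b = b)
print(mem_k_batched.shape)                             # torch.Size([4, 8, 2, 64])
```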
```diff
@@ -1302,8 +1307,8 @@ class AttentionLayers(Module):
         rotary_xpos_scale_base = 512,
         rotary_base_rescale_factor = 1.,
         weight_tie_layers = False,
-        custom_layers:
-        layers_execute_order:
+        custom_layers: tuple[str, ...] | None = None,
+        layers_execute_order: tuple[int, ...] | None = None,
         sandwich_coef = None,
         par_ratio = None,
         residual_attn = False,
```
```diff
@@ -1464,7 +1469,7 @@ class AttentionLayers(Module):
 
         if self.need_condition and adaptive_condition_mlp:
             self.adaptive_mlp = nn.Sequential(
-
+                LinearNoBias(dim_condition, dim_condition * dim_condition_mult),
                 nn.SiLU()
             )
 
```
```diff
@@ -1635,7 +1640,7 @@ class AttentionLayers(Module):
         return_hiddens = False,
         rotary_pos_emb = None,
         condition = None,
-        layers_execute_order:
+        layers_execute_order: tuple[int, ...] | None = None
     ):
         assert not (self.cross_attend ^ exists(context)), 'context must be passed in if cross_attend is set to True'
         assert not (exists(condition) ^ self.need_condition), 'condition needs to be passed in if using adaptive layernorm or vice versa'
```
```diff
@@ -1973,7 +1978,7 @@ class TransformerWrapper(Module):
         num_tokens,
         max_seq_len,
         attn_layers: AttentionLayers,
-        embed_num_tokens:
+        embed_num_tokens: dict[str, int] = dict(),
         emb_dim = None,
         max_mem_len = 0,
         shift_mem_down = 0,
```
```diff
@@ -1996,6 +2001,8 @@ class TransformerWrapper(Module):
         use_cls_token = False,
         squeeze_out_last_dim = False,
         token_emb: TokenEmbedding | None = None,
+        mixture_of_softmax = False,
+        mixture_of_softmax_k = 4,
     ):
         super().__init__()
 
```
```diff
@@ -2050,7 +2057,7 @@ class TransformerWrapper(Module):
         # maybe recycling
 
         self.recycling = recycling
-        self.recycled_proj =
+        self.recycled_proj = LinearNoBias(dim, dim) if recycling else None
 
         self.train_max_recycle_steps = train_max_recycle_steps
 
```
```diff
@@ -2066,21 +2073,37 @@ class TransformerWrapper(Module):
 
         self.average_pool_embed = average_pool_embed
 
+        # output type
+
+        self.is_log_prob = mixture_of_softmax
+
+        self.to_mixture = None
+        self.combine_mixture = None
+
+        if mixture_of_softmax:
+            assert num_output_heads == 1
+
+            self.to_mixture = Sequential(
+                LinearNoBias(dim, dim * mixture_of_softmax_k),
+                Rearrange('... (k d) -> ... k d', k = mixture_of_softmax_k)
+            )
+
+            self.combine_mixture = LinearNoBias(dim, mixture_of_softmax_k)
+
         # output head, usually to logits of num_tokens
 
         logits_dim = default(logits_dim, num_tokens)
 
-        self.has_multiple_heads =
+        self.has_multiple_heads = num_output_heads > 1
 
         if return_only_embed:
             self.to_logits = None
         elif tie_embedding:
             self.to_logits = lambda t: t @ self.token_emb.emb.weight.t()
         elif num_output_heads > 1:
-            self.
-            self.to_logits = ModuleList([nn.Linear(dim, logits_dim, bias = False) for _ in range(num_output_heads)])
+            self.to_logits = ModuleList([LinearNoBias(dim, logits_dim) for _ in range(num_output_heads)])
         else:
-            self.to_logits =
+            self.to_logits = LinearNoBias(dim, logits_dim)
 
         # memory tokens (like [cls]) from Memory Transformers paper
 
```
```diff
@@ -2124,7 +2147,7 @@ class TransformerWrapper(Module):
         pos = None,
         prepend_embeds = None,
         prepend_mask = None,
-        embed_ids:
+        embed_ids: dict[str, Tensor] = dict(),
         sum_embeds = None,
         return_attn_z_loss = False,
         attn_z_loss_weight = 1e-4,
```
```diff
@@ -2281,6 +2304,14 @@ class TransformerWrapper(Module):
         if exists(self.cls_token):
             x, _ = unpack(x, cls_packed_shape, 'b * d')
 
+        # handle expansion to mixture if needed (for mixture of softmax)
+
+        combine_mixture = None
+
+        if exists(self.to_mixture):
+            combine_mixture = self.combine_mixture(x).softmax(dim = -1)
+            x = self.to_mixture(x)
+
         # projecting to logits
 
         if not return_embeddings:
```
```diff
@@ -2289,6 +2320,14 @@ class TransformerWrapper(Module):
         else:
             logits = self.to_logits(x)
 
+        # handle maybe combine mixture
+
+        if exists(combine_mixture):
+            with autocast('cuda', enabled = False):
+                prob = logits.softmax(dim = -1)
+                mos = einsum('... k d, ... k -> ... d', prob, combine_mixture)
+                logits = log(mos)
+
         # maybe squeeze out last dimension of logits
 
         if self.squeeze_out_last_dim:
```
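Putting the new pieces together: `to_mixture` expands each hidden state into `k` component states, `combine_mixture` predicts the mixture weights, each component goes through the shared output head and a softmax, and the components are mixed in probability space before going back to log space (hence `is_log_prob` and the `nll_loss` switch in the autoregressive wrapper). A standalone sketch with assumed shapes, not the library's exact module wiring:

```python
import torch
from torch import nn, einsum

b, n, d, k, vocab = 2, 8, 16, 4, 100

x = torch.randn(b, n, d)                              # final hidden states
to_mixture      = nn.Linear(d, d * k, bias = False)   # per-component hiddens
combine_mixture = nn.Linear(d, k, bias = False)       # mixture weights
to_logits       = nn.Linear(d, vocab, bias = False)   # shared output head

weights = combine_mixture(x).softmax(dim = -1)        # (b, n, k)
xk = to_mixture(x).reshape(b, n, k, d)                # '... (k d) -> ... k d'

prob = to_logits(xk).softmax(dim = -1)                # softmax per component
mixed = einsum('... k d, ... k -> ... d', prob, weights)
log_probs = mixed.clamp(min = 1e-20).log()            # (b, n, vocab) log-probabilities
```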
```diff
@@ -2309,14 +2348,14 @@ class TransformerWrapper(Module):
         # aux loss
 
         if return_attn_z_loss:
-            pre_softmax_attns =
+            pre_softmax_attns = [t.pre_softmax_attn for t in intermediates.attn_intermediates]
             intermediates.attn_z_loss = calc_z_loss(pre_softmax_attns, weight = attn_z_loss_weight)
             return_intermediates = True
 
         if return_mems:
             hiddens = intermediates.hiddens
-            new_mems =
-            new_mems =
+            new_mems = [torch.cat(pair, dim = -2) for pair in zip(mems, hiddens)] if exists(mems) else hiddens
+            new_mems = [t[..., -self.max_mem_len:, :].detach() for t in new_mems]
 
         if not return_intermediates:
             return out, new_mems
```
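The memory update is now written as two list comprehensions: per layer, concatenate the previous memories with the fresh hidden states along the sequence dimension, then keep only the most recent `max_mem_len` positions, detached. A small worked example with assumed shapes:

```python
import torch

max_mem_len = 4
mems    = [torch.randn(2, 3, 8)]     # one layer's cached memory: (batch, mem len, dim)
hiddens = [torch.randn(2, 5, 8)]     # that layer's new hidden states

new_mems = [torch.cat(pair, dim = -2) for pair in zip(mems, hiddens)]
new_mems = [t[..., -max_mem_len:, :].detach() for t in new_mems]

print(new_mems[0].shape)             # torch.Size([2, 4, 8])
```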
```diff
@@ -2327,7 +2366,7 @@ class TransformerWrapper(Module):
             return out, intermediates
 
         if return_attn:
-            attn_maps =
+            attn_maps = [t.post_softmax_attn for t in intermediates.attn_intermediates]
             return out, attn_maps
 
         return out
```
{x_transformers-1.36.0.dist-info → x_transformers-1.37.0.dist-info}/RECORD CHANGED

```diff
@@ -1,15 +1,15 @@
 x_transformers/__init__.py,sha256=-MkQrSc37cTVDX7AOykxunYnqVtFlQ7lb0Cse5dsGWU,793
 x_transformers/attend.py,sha256=7q996VGYHGIsc0FQnN8WNiwHn3xny3i1biRwx7yW5vg,12090
-x_transformers/autoregressive_wrapper.py,sha256=
+x_transformers/autoregressive_wrapper.py,sha256=2FN4ZobFcdDGDGWEnUof_geb16dRGSJycZGwG899Pa4,10493
 x_transformers/continuous.py,sha256=cIVEdhfei258__ziV7kQBrJMxCel54bExBTDrO9rfCI,6450
 x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/nonautoregressive_wrapper.py,sha256=ys_p8obc7lTeeodCqvkRKxOXQ1C9T3j5Jwr-JbVgnXk,10432
-x_transformers/x_transformers.py,sha256=
+x_transformers/x_transformers.py,sha256=ztP6nNncVoPONR-al5lHIphAJQqNcE0mrT6tFWsnyPk,83281
 x_transformers/xl_autoregressive_wrapper.py,sha256=DCx4n0_c1tFai4nOqaWVnqx2p9eutsZsDMiMP1ckxNU,4117
 x_transformers/xval.py,sha256=QE1ltYZTR_eGgIHPP2BrMWVWVLqMW-OpDZh87BSmQEg,8563
-x_transformers-1.
-x_transformers-1.
-x_transformers-1.
-x_transformers-1.
-x_transformers-1.
+x_transformers-1.37.0.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.37.0.dist-info/METADATA,sha256=S8fQ4scePXn4pMl1_01cyWU8_3UXXBlLczibRSFuOoM,661
+x_transformers-1.37.0.dist-info/WHEEL,sha256=GV9aMThwP_4oNCtvEC2ec3qUYutgWeAzklro_0m4WJQ,91
+x_transformers-1.37.0.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.37.0.dist-info/RECORD,,
```
{x_transformers-1.36.0.dist-info → x_transformers-1.37.0.dist-info}/LICENSE: File without changes

{x_transformers-1.36.0.dist-info → x_transformers-1.37.0.dist-info}/WHEEL: File without changes

{x_transformers-1.36.0.dist-info → x_transformers-1.37.0.dist-info}/top_level.txt: File without changes