x-transformers 1.36.0__py3-none-any.whl → 1.37.1__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.
- x_transformers/autoregressive_wrapper.py +3 -1
- x_transformers/x_transformers.py +74 -33
- {x_transformers-1.36.0.dist-info → x_transformers-1.37.1.dist-info}/METADATA +1 -1
- {x_transformers-1.36.0.dist-info → x_transformers-1.37.1.dist-info}/RECORD +7 -7
- {x_transformers-1.36.0.dist-info → x_transformers-1.37.1.dist-info}/LICENSE +0 -0
- {x_transformers-1.36.0.dist-info → x_transformers-1.37.1.dist-info}/WHEEL +0 -0
- {x_transformers-1.36.0.dist-info → x_transformers-1.37.1.dist-info}/top_level.txt +0 -0
x_transformers/autoregressive_wrapper.py CHANGED
@@ -317,7 +317,9 @@ class AutoregressiveWrapper(Module):
             **kwargs
         )
 
-        loss = F.cross_entropy(
+        loss_fn = F.cross_entropy if not self.net.is_log_prob else F.nll_loss
+
+        loss = loss_fn(
             rearrange(logits, 'b n c -> b c n'),
             target,
             ignore_index = ignore_index
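The new loss_fn switch pairs with the mixture-of-softmax option added to TransformerWrapper further down in this diff: when the wrapped net reports is_log_prob, its output is already log-probabilities, so F.nll_loss is used instead of F.cross_entropy (which would apply a second log_softmax). A minimal standalone sketch of the equivalence, not taken from the package:

    import torch
    import torch.nn.functional as F

    logits = torch.randn(2, 5, 10)               # (batch, seq, classes), illustrative shapes
    target = torch.randint(0, 10, (2, 5))

    # cross entropy on raw logits ...
    ce = F.cross_entropy(logits.transpose(1, 2), target)

    # ... matches nll_loss on log-probabilities
    nll = F.nll_loss(logits.log_softmax(dim = -1).transpose(1, 2), target)

    assert torch.allclose(ce, nll, atol = 1e-6)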
x_transformers/x_transformers.py CHANGED
@@ -1,4 +1,5 @@
 from __future__ import annotations
+from typing import Callable
 
 import math
 from random import random, randrange
@@ -14,7 +15,6 @@ from functools import partial, wraps
 from collections import namedtuple
 from contextlib import nullcontext
 from dataclasses import dataclass
-from typing import List, Dict, Tuple, Callable
 
 from einops import rearrange, repeat, reduce, pack, unpack
 from einops.layers.torch import Rearrange
@@ -28,14 +28,16 @@ DEFAULT_DIM_HEAD = 64
 
 @dataclass
 class LayerIntermediates:
-    hiddens:
+    hiddens: list[Tensor] | None = None              # all hiddens, before the final norm (in pre-norm architecture)
     last_hidden: Tensor | None = None                # very last hidden after all attention layers, after the final norm
-    attn_intermediates:
-    layer_hiddens:
+    attn_intermediates: list[Intermediates] | None = None
+    layer_hiddens: list[Tensor] | None = None
     attn_z_loss: Tensor | None = None
     mems: Tensor | None = None
     memory_tokens: Tensor | None = None
 
+LinearNoBias = partial(nn.Linear, bias = False)
+
 # helpers
 
 def exists(val):
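LinearNoBias is simply functools.partial pre-binding bias = False, so the many nn.Linear(..., bias = False) projections in this file can be declared as LinearNoBias(...). A standalone sketch of the equivalence:

    from functools import partial
    import torch.nn as nn

    LinearNoBias = partial(nn.Linear, bias = False)

    proj = LinearNoBias(512, 64)       # same module as nn.Linear(512, 64, bias = False)
    assert proj.bias is None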
@@ -92,6 +94,9 @@ def Sequential(*modules):
 
 # tensor helpers
 
+def log(t, eps = 1e-20):
+    return t.clamp(min = eps).log()
+
 def max_neg_value(tensor):
     return -torch.finfo(tensor.dtype).max
 
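The log helper clamps its input before taking the logarithm, so probabilities that underflow to exactly zero (possible in the mixture-of-softmax path added below) map to a large negative value rather than -inf. A quick illustration, assuming only torch:

    import torch

    def log(t, eps = 1e-20):
        return t.clamp(min = eps).log()

    probs = torch.tensor([0.0, 0.5, 1.0])
    torch.log(probs)    # tensor([    -inf, -0.6931,  0.0000])
    log(probs)          # tensor([-46.0517, -0.6931,  0.0000]), finite everywhere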
@@ -114,7 +119,7 @@ def masked_mean(t, mask = None, dim = 1):
     den = mask.sum(dim = dim).clamp(min = 1.)
     return num / den
 
-def pad_at_dim(t, pad:
+def pad_at_dim(t, pad: tuple[int, int], dim = -1, value = 0.):
     if pad == (0, 0):
         return t
 
@@ -131,7 +136,7 @@ def or_reduce(masks):
 # auxiliary loss helpers
 
 def calc_z_loss(
-    pre_softmax_attns:
+    pre_softmax_attns: list[Tensor],
     mask = None,
     weight = 1.
 ):
@@ -611,7 +616,7 @@ class AdaptiveLayerNorm(Module):
         dim_condition = default(dim_condition, dim)
 
         self.ln = nn.LayerNorm(dim, elementwise_affine = False)
-        self.to_gamma =
+        self.to_gamma = LinearNoBias(dim_condition, dim)
         nn.init.zeros_(self.to_gamma.weight)
 
     def forward(self, x, *, condition):
@@ -666,7 +671,7 @@ class AdaptiveRMSNorm(Module):
         self.scale = dim ** 0.5
         dim_condition = default(dim_condition, dim)
 
-        self.to_gamma =
+        self.to_gamma = LinearNoBias(dim_condition, dim)
         nn.init.zeros_(self.to_gamma.weight)
 
     def forward(self, x, *, condition):
@@ -749,7 +754,7 @@ class ShiftTokens(Module):
         feats_per_shift = x.shape[-1] // segments
         splitted = x.split(feats_per_shift, dim = -1)
         segments_to_shift, rest = splitted[:segments], splitted[segments:]
-        segments_to_shift =
+        segments_to_shift = [shift(*args, mask = mask) for args in zip(segments_to_shift, shifts)]
         x = torch.cat((*segments_to_shift, *rest), dim = -1)
         return self.fn(x, **kwargs)
 
@@ -817,7 +822,7 @@ class ConcatCombine(Module):
     def __init__(self, dim, prev_layer_ind):
         super().__init__()
         self.prev_layer_ind = prev_layer_ind
-        self.combine =
+        self.combine = LinearNoBias(dim * 2, dim)
 
     def forward(self, x, prev_layers: list[Tensor]):
         skip = prev_layers[self.prev_layer_ind]
@@ -957,17 +962,17 @@ class Attention(Module):
         v_dim = value_dim_head * kv_heads
         out_dim = value_dim_head * heads
 
-        self.to_q =
-        self.to_k =
+        self.to_q = LinearNoBias(dim, q_dim)
+        self.to_k = LinearNoBias(dim_kv, k_dim)
 
         # shared key / values, for further memory savings during inference
 
         assert not (shared_kv and value_dim_head != dim_head), 'key and value head dimensions must be equal for shared key / values'
-        self.to_v =
+        self.to_v = LinearNoBias(dim_kv, v_dim) if not shared_kv else None
 
         # relations projection from tp-attention
 
-        self.to_r =
+        self.to_r = LinearNoBias(dim, v_dim) if tensor_product else None
 
         # add GLU gating for aggregated values, from alphafold2
 
@@ -1063,7 +1068,7 @@ class Attention(Module):
         # output dimension by default same as input, but can be overridden
 
         dim_out = default(dim_out, dim)
-        self.to_out = nn.Sequential(
+        self.to_out = nn.Sequential(LinearNoBias(out_dim, dim_out * 2), nn.GLU()) if on_attn else LinearNoBias(out_dim, dim_out)
 
         # whether to rotate positions into values, for absolute positions in addition to relative
 
@@ -1109,7 +1114,7 @@ class Attention(Module):
 
         q = rearrange(q, 'b n (h d) -> b h n d', h = h)
 
-        k, v, r =
+        k, v, r = tuple(maybe(rearrange)(t, 'b n (h d) -> b h n d', h = kv_h) for t in (k, v, r))
 
         if exists(cache):
             ck, cv = cache.cached_kv
@@ -1164,12 +1169,12 @@ class Attention(Module):
 
         # i, j determined for relative positional bias, excluding memory key / values
 
-        i, j =
+        i, j = tuple(t.shape[-2] for t in (q, k))
 
         # maybe append memory key / values
 
         if num_mem_kv > 0:
-            mem_k, mem_v =
+            mem_k, mem_v = tuple(repeat(t, 'h n d -> b h n d', b = b) for t in (self.mem_k, self.mem_v))
 
             if self.qk_norm:
                 mem_k = l2norm(mem_k)
@@ -1302,8 +1307,8 @@ class AttentionLayers(Module):
         rotary_xpos_scale_base = 512,
         rotary_base_rescale_factor = 1.,
         weight_tie_layers = False,
-        custom_layers:
-        layers_execute_order:
+        custom_layers: tuple[str, ...] | None = None,
+        layers_execute_order: tuple[int, ...] | None = None,
         sandwich_coef = None,
         par_ratio = None,
         residual_attn = False,
@@ -1464,7 +1469,7 @@ class AttentionLayers(Module):
 
         if self.need_condition and adaptive_condition_mlp:
             self.adaptive_mlp = nn.Sequential(
-                nn.Linear(dim_condition, dim_condition * dim_condition_mult, bias = False),
+                LinearNoBias(dim_condition, dim_condition * dim_condition_mult),
                 nn.SiLU()
             )
 
@@ -1635,7 +1640,7 @@ class AttentionLayers(Module):
         return_hiddens = False,
         rotary_pos_emb = None,
         condition = None,
-        layers_execute_order:
+        layers_execute_order: tuple[int, ...] | None = None
     ):
         assert not (self.cross_attend ^ exists(context)), 'context must be passed in if cross_attend is set to True'
         assert not (exists(condition) ^ self.need_condition), 'condition needs to be passed in if using adaptive layernorm or vice versa'
@@ -1973,7 +1978,7 @@ class TransformerWrapper(Module):
         num_tokens,
         max_seq_len,
         attn_layers: AttentionLayers,
-        embed_num_tokens:
+        embed_num_tokens: dict[str, int] = dict(),
         emb_dim = None,
         max_mem_len = 0,
         shift_mem_down = 0,
@@ -1996,6 +2001,8 @@ class TransformerWrapper(Module):
         use_cls_token = False,
         squeeze_out_last_dim = False,
         token_emb: TokenEmbedding | None = None,
+        mixture_of_softmax = False,
+        mixture_of_softmax_k = 4,
     ):
         super().__init__()
 
@@ -2050,7 +2057,7 @@ class TransformerWrapper(Module):
         # maybe recycling
 
         self.recycling = recycling
-        self.recycled_proj =
+        self.recycled_proj = LinearNoBias(dim, dim) if recycling else None
 
         self.train_max_recycle_steps = train_max_recycle_steps
 
@@ -2066,21 +2073,37 @@ class TransformerWrapper(Module):
 
         self.average_pool_embed = average_pool_embed
 
+        # output type
+
+        self.is_log_prob = mixture_of_softmax
+
+        self.to_mixture = None
+        self.combine_mixture = None
+
+        if mixture_of_softmax:
+            assert num_output_heads == 1
+
+            self.to_mixture = Sequential(
+                LinearNoBias(dim, dim * mixture_of_softmax_k),
+                Rearrange('... (k d) -> ... k d', k = mixture_of_softmax_k)
+            )
+
+            self.combine_mixture = LinearNoBias(dim, mixture_of_softmax_k)
+
         # output head, usually to logits of num_tokens
 
         logits_dim = default(logits_dim, num_tokens)
 
-        self.has_multiple_heads =
+        self.has_multiple_heads = num_output_heads > 1
 
         if return_only_embed:
             self.to_logits = None
         elif tie_embedding:
             self.to_logits = lambda t: t @ self.token_emb.emb.weight.t()
         elif num_output_heads > 1:
-            self.
-            self.to_logits = ModuleList([nn.Linear(dim, logits_dim, bias = False) for _ in range(num_output_heads)])
+            self.to_logits = ModuleList([LinearNoBias(dim, logits_dim) for _ in range(num_output_heads)])
         else:
-            self.to_logits =
+            self.to_logits = LinearNoBias(dim, logits_dim)
 
         # memory tokens (like [cls]) from Memory Transformers paper
 
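This is a mixture-of-softmaxes output head (in the spirit of Yang et al., "Breaking the Softmax Bottleneck"): each hidden state is expanded into mixture_of_softmax_k component states plus k mixture weights, k softmax distributions are blended, and the log of the blend is returned, which is why is_log_prob is set. A shape-level sketch of the computation with toy sizes (only k = 4 matches the package default):

    import torch
    from torch import nn, einsum
    from einops.layers.torch import Rearrange

    dim, num_tokens, k = 512, 1000, 4                          # illustrative sizes

    to_mixture = nn.Sequential(
        nn.Linear(dim, dim * k, bias = False),
        Rearrange('... (k d) -> ... k d', k = k)
    )
    combine_mixture = nn.Linear(dim, k, bias = False)
    to_logits = nn.Linear(dim, num_tokens, bias = False)

    x = torch.randn(2, 16, dim)                                # (batch, seq, dim) hidden states
    weights = combine_mixture(x).softmax(dim = -1)             # (b, n, k) mixture weights
    h = to_mixture(x)                                          # (b, n, k, dim) component states
    probs = to_logits(h).softmax(dim = -1)                     # (b, n, k, num_tokens) k softmaxes
    mixed = einsum('... k d, ... k -> ... d', probs, weights)  # weighted blend over k
    log_probs = mixed.clamp(min = 1e-20).log()                 # returned in place of raw logits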
@@ -2124,7 +2147,7 @@ class TransformerWrapper(Module):
         pos = None,
         prepend_embeds = None,
         prepend_mask = None,
-        embed_ids:
+        embed_ids: dict[str, Tensor] = dict(),
         sum_embeds = None,
         return_attn_z_loss = False,
         attn_z_loss_weight = 1e-4,
@@ -2235,6 +2258,8 @@ class TransformerWrapper(Module):
         # attention layers
 
         if not self.recycling:
+            assert recycle_steps == 1, 'you did not train with recycling'
+
             # regular
 
             attended, intermediates = self.attn_layers(x, mask = mask, mems = mems, mem_masks = mem_masks, cache = cache, return_hiddens = True, seq_start_pos = seq_start_pos, **kwargs)
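The new assert guards against requesting multiple recycle_steps at inference from a model that was not constructed with recycling enabled. A hedged usage sketch (constructor and forward arguments as they appear in this file; sizes are illustrative):

    import torch
    from x_transformers import TransformerWrapper, Decoder

    model = TransformerWrapper(
        num_tokens = 256,
        max_seq_len = 128,
        recycling = True,                  # enables recycled_proj and multi-step refinement
        train_max_recycle_steps = 4,
        attn_layers = Decoder(dim = 256, depth = 4, heads = 4)
    )

    x = torch.randint(0, 256, (1, 128))
    logits = model(x, recycle_steps = 3)   # would trip the new assert if recycling were False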
@@ -2281,6 +2306,14 @@ class TransformerWrapper(Module):
         if exists(self.cls_token):
             x, _ = unpack(x, cls_packed_shape, 'b * d')
 
+        # handle expansion to mixture if needed (for mixture of softmax)
+
+        combine_mixture = None
+
+        if exists(self.to_mixture):
+            combine_mixture = self.combine_mixture(x).softmax(dim = -1)
+            x = self.to_mixture(x)
+
         # projecting to logits
 
         if not return_embeddings:
@@ -2289,6 +2322,14 @@ class TransformerWrapper(Module):
         else:
             logits = self.to_logits(x)
 
+        # handle maybe combine mixture
+
+        if exists(combine_mixture):
+            with autocast('cuda', enabled = False):
+                prob = logits.softmax(dim = -1)
+                mos = einsum('... k d, ... k -> ... d', prob, combine_mixture)
+                logits = log(mos)
+
         # maybe squeeze out last dimension of logits
 
         if self.squeeze_out_last_dim:
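Taken together with the AutoregressiveWrapper change at the top of this diff, a wrapper built with mixture_of_softmax = True now returns log-probabilities (is_log_prob = True), and the training wrapper switches to F.nll_loss accordingly. A hedged end-to-end sketch following the library's usual usage pattern (hyperparameters are illustrative):

    import torch
    from x_transformers import TransformerWrapper, Decoder, AutoregressiveWrapper

    model = TransformerWrapper(
        num_tokens = 20000,
        max_seq_len = 1024,
        mixture_of_softmax = True,       # new in 1.37.x; sets model.is_log_prob
        attn_layers = Decoder(dim = 512, depth = 6, heads = 8)
    )

    wrapper = AutoregressiveWrapper(model)

    seq = torch.randint(0, 20000, (1, 1024))
    loss = wrapper(seq)                  # internally picks F.nll_loss since net.is_log_prob
    loss.backward()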
@@ -2309,14 +2350,14 @@ class TransformerWrapper(Module):
         # aux loss
 
         if return_attn_z_loss:
-            pre_softmax_attns =
+            pre_softmax_attns = [t.pre_softmax_attn for t in intermediates.attn_intermediates]
             intermediates.attn_z_loss = calc_z_loss(pre_softmax_attns, weight = attn_z_loss_weight)
             return_intermediates = True
 
         if return_mems:
             hiddens = intermediates.hiddens
-            new_mems =
-            new_mems =
+            new_mems = [torch.cat(pair, dim = -2) for pair in zip(mems, hiddens)] if exists(mems) else hiddens
+            new_mems = [t[..., -self.max_mem_len:, :].detach() for t in new_mems]
 
             if not return_intermediates:
                 return out, new_mems
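The rewritten new_mems lines keep the transformer-XL style behavior: the current segment's hiddens are concatenated onto incoming mems along the sequence dimension, truncated to the trailing max_mem_len positions, and detached. A hedged sketch of threading memories across segments (argument names as in this file; sizes illustrative):

    import torch
    from x_transformers import TransformerWrapper, Decoder

    model = TransformerWrapper(
        num_tokens = 256,
        max_seq_len = 128,
        max_mem_len = 128,               # trailing positions kept per layer
        attn_layers = Decoder(dim = 256, depth = 4, heads = 4)
    )

    segments = torch.randint(0, 256, (1, 384)).chunk(3, dim = -1)

    mems = None
    for segment in segments:
        logits, mems = model(segment, mems = mems, return_mems = True)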
@@ -2327,7 +2368,7 @@ class TransformerWrapper(Module):
             return out, intermediates
 
         if return_attn:
-            attn_maps =
+            attn_maps = [t.post_softmax_attn for t in intermediates.attn_intermediates]
             return out, attn_maps
 
         return out
{x_transformers-1.36.0.dist-info → x_transformers-1.37.1.dist-info}/RECORD CHANGED
@@ -1,15 +1,15 @@
 x_transformers/__init__.py,sha256=-MkQrSc37cTVDX7AOykxunYnqVtFlQ7lb0Cse5dsGWU,793
 x_transformers/attend.py,sha256=7q996VGYHGIsc0FQnN8WNiwHn3xny3i1biRwx7yW5vg,12090
-x_transformers/autoregressive_wrapper.py,sha256=
+x_transformers/autoregressive_wrapper.py,sha256=2FN4ZobFcdDGDGWEnUof_geb16dRGSJycZGwG899Pa4,10493
 x_transformers/continuous.py,sha256=cIVEdhfei258__ziV7kQBrJMxCel54bExBTDrO9rfCI,6450
 x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/nonautoregressive_wrapper.py,sha256=ys_p8obc7lTeeodCqvkRKxOXQ1C9T3j5Jwr-JbVgnXk,10432
-x_transformers/x_transformers.py,sha256=
+x_transformers/x_transformers.py,sha256=9lk6wtz0vNigyLoMWleo442Q0mhce-BCxEhazhSHuvI,83356
 x_transformers/xl_autoregressive_wrapper.py,sha256=DCx4n0_c1tFai4nOqaWVnqx2p9eutsZsDMiMP1ckxNU,4117
 x_transformers/xval.py,sha256=QE1ltYZTR_eGgIHPP2BrMWVWVLqMW-OpDZh87BSmQEg,8563
-x_transformers-1.
-x_transformers-1.
-x_transformers-1.
-x_transformers-1.
-x_transformers-1.
+x_transformers-1.37.1.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.37.1.dist-info/METADATA,sha256=ik8UKwzq_pW9zdxCl6pt7POrjRC7_GwIi6gAnY7Fck0,661
+x_transformers-1.37.1.dist-info/WHEEL,sha256=GV9aMThwP_4oNCtvEC2ec3qUYutgWeAzklro_0m4WJQ,91
+x_transformers-1.37.1.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.37.1.dist-info/RECORD,,
File without changes
|
File without changes
|
File without changes
|