x-transformers 2.3.7__py3-none-any.whl → 2.3.9__py3-none-any.whl
This diff reflects the changes between two publicly released versions of the package, as published to a supported public registry. It is provided for informational purposes only.
- x_transformers/x_transformers.py +40 -21
- {x_transformers-2.3.7.dist-info → x_transformers-2.3.9.dist-info}/METADATA +1 -1
- {x_transformers-2.3.7.dist-info → x_transformers-2.3.9.dist-info}/RECORD +5 -5
- {x_transformers-2.3.7.dist-info → x_transformers-2.3.9.dist-info}/WHEEL +0 -0
- {x_transformers-2.3.7.dist-info → x_transformers-2.3.9.dist-info}/licenses/LICENSE +0 -0
x_transformers/x_transformers.py
CHANGED
@@ -1208,8 +1208,6 @@ class FeedForward(Module):
         sublayer_dropout = 0.,
         no_bias = False,
         zero_init_output = False,
-        deep_embed_hiddens = False,
-        deep_embed_num_tokens = None,
     ):
         super().__init__()
         inner_dim = int(dim * mult)
@@ -1242,17 +1240,6 @@ class FeedForward(Module):
             nn.Dropout(sublayer_dropout) if sublayer_dropout > 0. else None
         )

-        # deep embed
-
-        # credit goes to Braden Koszarsky for first devising value embeddings in nanogpt-speedrun project
-        # then Bo Peng for coming up with this alternate design in feedforward for RWKV 8
-        # improvements were clearest to me (on my toy setup) with multiplying on output of feedforward, will try with attention at future date
-
-        self.deep_embed = None
-        if deep_embed_hiddens:
-            assert exists(deep_embed_num_tokens)
-            self.deep_embed = nn.Parameter(torch.ones(deep_embed_num_tokens, dim_out))
-
         # init last linear layer to 0

         if zero_init_output:
@@ -1261,12 +1248,11 @@ class FeedForward(Module):
     def forward(
         self,
         x,
-
+        deep_embed = None
     ):
         out = self.ff(x)

-        if exists(
-            deep_embed = self.deep_embed[deep_embed_ids]
+        if exists(deep_embed):
             out = out * deep_embed

         return out
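The net effect of these hunks is that FeedForward no longer owns its deep-embed table; the caller now gathers the per-token multiplier and passes it in as `deep_embed`. A minimal sketch of the new contract (the sizes and the standalone `ff` stack are illustrative assumptions, not the library's internals):

```python
import torch
from torch import nn

# sketch only: a toy feedforward whose output is scaled elementwise by a
# per-token "deep embed" row that the caller gathers and passes in
dim, num_tokens = 512, 20000
ff = nn.Sequential(nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim))

deep_embed_table = nn.Parameter(torch.ones(num_tokens, dim))  # ones -> identity multiplier at init

x = torch.randn(2, 128, dim)
token_ids = torch.randint(0, num_tokens, (2, 128))

deep_embed = deep_embed_table[token_ids]  # (2, 128, dim), gathered outside the module
out = ff(x) * deep_embed                  # mirrors `out = out * deep_embed` above
```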
@@ -2380,7 +2366,7 @@ class AttentionLayers(Module):
         pos = None,
         context_pos = None,
         attn_bias = None,
-
+        deep_embeds_and_ids: tuple[nn.Parameter, Tensor] | None = None,
         condition = None,
         in_attn_cond = None, # https://arxiv.org/abs/2105.04090
         layers_execute_order: tuple[int, ...] | None = None
@@ -2475,13 +2461,26 @@ class AttentionLayers(Module):
             if cache_age > 0:
                 x = x[:, -cache_age:] # for spec decoding, may be greater than 1

-                if exists(
-
+                if exists(deep_embeds_and_ids):
+                    deep_embeds, token_ids = deep_embeds_and_ids
+                    token_ids = token_ids[:, -cache_age:]
+                    deep_embeds_and_ids = (deep_embeds, token_ids)

             attn_cache = cache.attn_intermediates

         iter_attn_cache = iter(attn_cache)

+        # handle deep embeds if needed
+
+        deep_embeds = []
+
+        if exists(deep_embeds_and_ids):
+            deep_embeds, token_ids = deep_embeds_and_ids
+            deep_embeds_across_depth = deep_embeds[token_ids]
+            deep_embeds = rearrange(deep_embeds_across_depth, 'b n l d -> l b n d')
+
+        deep_embeds_iter = iter(deep_embeds)
+
         # setup multistreams if needed

         streams = self.num_residual_streams
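The block added above gathers every layer's multiplier in one shot: the `(num_tokens, depth, dim)` parameter is indexed with the `(batch, seq)` token ids, giving `(batch, seq, depth, dim)`, which is then rearranged so the depth axis leads and can be consumed once per layer. A small shape check, with illustrative sizes:

```python
import torch
from einops import rearrange

batch, seq, depth, dim, num_tokens = 2, 16, 6, 512, 1000

deep_embeds = torch.ones(num_tokens, depth, dim)        # one multiplier per token per layer
token_ids = torch.randint(0, num_tokens, (batch, seq))

per_token = deep_embeds[token_ids]                      # (b, n, l, d) -> (2, 16, 6, 512)
per_layer = rearrange(per_token, 'b n l d -> l b n d')  # depth axis first

deep_embeds_iter = iter(per_layer)                      # each 'f' block pulls one (b, n, d) slice
assert next(deep_embeds_iter).shape == (batch, seq, dim)
```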
@@ -2602,7 +2601,7 @@ class AttentionLayers(Module):
             elif layer_type == 'c':
                 out, inter = block(x, context = context, mask = mask, context_mask = context_mask, prev_attn = prev_cross_attn, cache = next(iter_attn_cache, None), value_residual = maybe_cross_attn_value_residual, **cross_attn_rotary_pos_emb, return_intermediates = True)
             elif layer_type == 'f':
-                out = block(x,
+                out = block(x, deep_embed = next(deep_embeds_iter, None))

             # store first self or cross attention intermediate for value residual

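Note the `next(deep_embeds_iter, None)` default: when deep embeds are disabled, `deep_embeds` stays an empty list, the iterator is exhausted immediately, and every feedforward block simply receives `None`, so the feature is a no-op. For instance:

```python
deep_embeds_iter = iter([])                  # what the disabled path produces
assert next(deep_embeds_iter, None) is None  # each 'f' block then sees deep_embed = None
```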
@@ -2818,11 +2817,14 @@ class TransformerWrapper(Module):
         mixture_of_softmax = False,
         mixture_of_softmax_k = 4,
         sigsoftmax_logits = False,
+        ff_deep_embed = False,
         to_logits: Module | None = None,
     ):
         super().__init__()

         dim = attn_layers.dim
+        depth = attn_layers.depth
+
         emb_dim = default(emb_dim, dim)
         self.emb_dim = emb_dim
         self.num_tokens = num_tokens
@@ -2855,6 +2857,16 @@ class TransformerWrapper(Module):
         if len(embed_num_tokens) > 0:
             self.embeds = ModuleDict({f'{name}_embed': nn.Embedding(num_tokens, emb_dim) for name, num_tokens in embed_num_tokens.items()})

+        # deep embed
+
+        # credit goes to Braden Koszarsky for first devising value embeddings in nanogpt-speedrun project
+        # then Bo Peng for coming up with this alternate design in feedforward for RWKV 8
+        # improvements were clearest to me (on my toy setup) with multiplying on output of feedforward, will try with attention at future date
+
+        self.ff_deep_embed = None
+        if ff_deep_embed:
+            self.ff_deep_embed = nn.Parameter(torch.ones(num_tokens, depth, dim))
+
         # fraction of the gradient that should go to the embedding, https://arxiv.org/abs/2105.13290

         self.emb_frac_gradient = emb_frac_gradient
@@ -3050,6 +3062,13 @@ class TransformerWrapper(Module):
 
         x = self.project_emb(x)

+        # maybe deep embeds
+
+        deep_embed_and_ids = None
+
+        if exists(self.ff_deep_embed):
+            deep_embed_and_ids = (self.ff_deep_embed, token_ids)
+
         # maybe cls token

         if exists(self.cls_token):
@@ -3096,7 +3115,7 @@ class TransformerWrapper(Module):
 
             # regular

-            attended, intermediates = self.attn_layers(x, mask = mask, mems = mems, mem_masks = mem_masks, cache = cache,
+            attended, intermediates = self.attn_layers(x, mask = mask, mems = mems, mem_masks = mem_masks, cache = cache, deep_embeds_and_ids = deep_embed_and_ids, return_hiddens = True, seq_start_pos = seq_start_pos, **kwargs)

         else:
             # recycling
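Putting the pieces together, the feature is switched on from TransformerWrapper via the new `ff_deep_embed` flag; the rest of the diff is plumbing. A usage sketch, assuming the usual TransformerWrapper / Decoder entry points (sizes are arbitrary):

```python
import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    ff_deep_embed = True,  # new flag: learned per-token multipliers on every feedforward output
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8
    )
)

tokens = torch.randint(0, 20000, (1, 256))
logits = model(tokens)  # (1, 256, 20000)
```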
{x_transformers-2.3.7.dist-info → x_transformers-2.3.9.dist-info}/RECORD
CHANGED
@@ -8,10 +8,10 @@ x_transformers/entropy_based_tokenizer.py,sha256=F2lO8-v3aLIcVDVNhu7RR-UtRdlmaaY
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/neo_mlp.py,sha256=XCNnnop9WLarcxap1kGuYc1x8GHvwkZiDRnXOxSl3Po,3452
 x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
-x_transformers/x_transformers.py,sha256=
+x_transformers/x_transformers.py,sha256=Wo5hauzdn4Q9PUVjBqQo-1vCq08BT2jYUDbq3r2a5Go,114061
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=7S00kCuab4tWQa-vf-z-XfzADjVj48MoFIr7VSIvttg,8575
-x_transformers-2.3.
-x_transformers-2.3.
-x_transformers-2.3.
-x_transformers-2.3.
+x_transformers-2.3.9.dist-info/METADATA,sha256=M0rUg95el9swG5VFWas-5YjahLSgytT65W9L3Ne7BJM,89021
+x_transformers-2.3.9.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+x_transformers-2.3.9.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-2.3.9.dist-info/RECORD,,
{x_transformers-2.3.7.dist-info → x_transformers-2.3.9.dist-info}/WHEEL
File without changes

{x_transformers-2.3.7.dist-info → x_transformers-2.3.9.dist-info}/licenses/LICENSE
File without changes