x-transformers 2.3.5__py3-none-any.whl → 2.3.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- x_transformers/continuous.py +0 -2
- x_transformers/x_transformers.py +41 -11
- {x_transformers-2.3.5.dist-info → x_transformers-2.3.7.dist-info}/METADATA +12 -1
- {x_transformers-2.3.5.dist-info → x_transformers-2.3.7.dist-info}/RECORD +6 -6
- {x_transformers-2.3.5.dist-info → x_transformers-2.3.7.dist-info}/WHEEL +0 -0
- {x_transformers-2.3.5.dist-info → x_transformers-2.3.7.dist-info}/licenses/LICENSE +0 -0
x_transformers/continuous.py
CHANGED
@@ -220,8 +220,6 @@ class ContinuousAutoregressiveWrapper(Module):
     def __init__(
         self,
         net: ContinuousTransformerWrapper,
-        ignore_index = -100,
-        pad_value = 0,
         loss_fn: Module | None = None,
         equal_loss_weight_batch = False # setting this to True, if the mask is passed in and sequences are variable in length, each sequence will be weighted the same (as opposed to each token)
     ):
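The two dropped arguments look like vestiges of the discrete autoregressive wrapper; with them gone, loss customization for the continuous wrapper goes through `loss_fn`. Below is a hedged sketch of constructing the wrapper against the 2.3.7 signature shown above; the `Decoder` hyperparameters, the `nn.MSELoss(reduction = 'none')` choice, and the training snippet are illustrative assumptions, not taken from this diff.

```python
# Hedged sketch, not from the diff: ignore_index / pad_value are gone,
# so custom loss behaviour is expressed via loss_fn instead.
import torch
from torch import nn
from x_transformers import Decoder, ContinuousTransformerWrapper, ContinuousAutoregressiveWrapper

model = ContinuousTransformerWrapper(
    dim_in = 32,
    dim_out = 32,
    max_seq_len = 1024,
    attn_layers = Decoder(dim = 512, depth = 6, heads = 8)
)

wrapper = ContinuousAutoregressiveWrapper(
    model,
    loss_fn = nn.MSELoss(reduction = 'none'),  # unreduced loss so masking / per-sequence weighting can still be applied
    equal_loss_weight_batch = True             # only takes effect when a mask is passed and sequences vary in length
)

seq = torch.randn(1, 128, 32)  # (batch, seq, dim_in) of continuous values
loss = wrapper(seq)
loss.backward()
```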
x_transformers/x_transformers.py
CHANGED
@@ -1207,7 +1207,9 @@ class FeedForward(Module):
         dropout = 0.,
         sublayer_dropout = 0.,
         no_bias = False,
-        zero_init_output = False
+        zero_init_output = False,
+        deep_embed_hiddens = False,
+        deep_embed_num_tokens = None,
     ):
         super().__init__()
         inner_dim = int(dim * mult)
@@ -1223,27 +1225,51 @@ class FeedForward(Module):
             activation = nn.GELU()

         if glu:
-
+            proj_in = GLU(dim, inner_dim, activation, mult_bias = glu_mult_bias)
         else:
-
+            proj_in = nn.Sequential(
                 nn.Linear(dim, inner_dim, bias = not no_bias),
                 activation
             )

+        proj_out = nn.Linear(inner_dim, dim_out, bias = not no_bias)
+
         self.ff = Sequential(
-
+            proj_in,
             LayerNorm(inner_dim) if post_act_ln else None,
             nn.Dropout(dropout),
-
+            proj_out,
             nn.Dropout(sublayer_dropout) if sublayer_dropout > 0. else None
         )

+        # deep embed
+
+        # credit goes to Braden Koszarsky for first devising value embeddings in nanogpt-speedrun project
+        # then Bo Peng for coming up with this alternate design in feedforward for RWKV 8
+        # improvements were clearest to me (on my toy setup) with multiplying on output of feedforward, will try with attention at future date
+
+        self.deep_embed = None
+        if deep_embed_hiddens:
+            assert exists(deep_embed_num_tokens)
+            self.deep_embed = nn.Parameter(torch.ones(deep_embed_num_tokens, dim_out))
+
         # init last linear layer to 0
+
         if zero_init_output:
-            init_zero_(
+            init_zero_(proj_out)

-    def forward(
-
+    def forward(
+        self,
+        x,
+        deep_embed_ids = None
+    ):
+        out = self.ff(x)
+
+        if exists(deep_embed_ids) and exists(self.deep_embed):
+            deep_embed = self.deep_embed[deep_embed_ids]
+            out = out * deep_embed
+
+        return out

 # attention. it is all we need

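The heart of this hunk is the new deep embed: a `(num_tokens, dim_out)` table of per-token vectors, initialized to ones, that multiplicatively gates the feedforward output for whatever token id sits at each position. Here is a standalone sketch of the mechanism in isolation; shapes and names are illustrative, not the library's internals.

```python
# Standalone sketch of the deep-embed gating added to FeedForward in this diff
# (illustrative shapes, not the library class itself).
import torch
from torch import nn

num_tokens, dim = 256, 64

ff = nn.Sequential(
    nn.Linear(dim, dim * 4),
    nn.GELU(),
    nn.Linear(dim * 4, dim)
)

# one learned gate vector per vocabulary token, starting as the identity (all ones)
deep_embed = nn.Parameter(torch.ones(num_tokens, dim))

x = torch.randn(2, 16, dim)                        # (batch, seq, dim) hidden states
token_ids = torch.randint(0, num_tokens, (2, 16))  # raw ids of the tokens at each position

out = ff(x) * deep_embed[token_ids]                # gate each position's output by its token's vector
print(out.shape)                                   # torch.Size([2, 16, 64])
```

Because the table starts as all ones, the gate is a no-op at initialization, so the module behaves exactly like a plain feedforward until the per-token vectors learn to deviate.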
@@ -2354,6 +2380,7 @@ class AttentionLayers(Module):
         pos = None,
         context_pos = None,
         attn_bias = None,
+        deep_embed_ids = None,
         condition = None,
         in_attn_cond = None, # https://arxiv.org/abs/2105.04090
         layers_execute_order: tuple[int, ...] | None = None
@@ -2448,6 +2475,9 @@ class AttentionLayers(Module):
             if cache_age > 0:
                 x = x[:, -cache_age:] # for spec decoding, may be greater than 1

+            if exists(deep_embed_ids):
+                deep_embed_ids = deep_embed_ids[:, -cache_age:]
+
             attn_cache = cache.attn_intermediates

             iter_attn_cache = iter(attn_cache)
@@ -2572,7 +2602,7 @@ class AttentionLayers(Module):
             elif layer_type == 'c':
                 out, inter = block(x, context = context, mask = mask, context_mask = context_mask, prev_attn = prev_cross_attn, cache = next(iter_attn_cache, None), value_residual = maybe_cross_attn_value_residual, **cross_attn_rotary_pos_emb, return_intermediates = True)
             elif layer_type == 'f':
-                out = block(x)
+                out = block(x, deep_embed_ids = deep_embed_ids)

             # store first self or cross attention intermediate for value residual

@@ -2959,7 +2989,7 @@ class TransformerWrapper(Module):

         # shapes and variables

-        b, n, device, num_mems, has_memory_tokens, emb_frac_gradient, orig_mask = x.shape[0], x.shape[1], x.device, self.num_memory_tokens, self.num_memory_tokens > 0, self.emb_frac_gradient, mask
+        b, n, device, token_ids, num_mems, has_memory_tokens, emb_frac_gradient, orig_mask = x.shape[0], x.shape[1], x.device, x, self.num_memory_tokens, self.num_memory_tokens > 0, self.emb_frac_gradient, mask

         return_hiddens = return_mems | return_attn | return_intermediates | return_attn_z_loss | return_embeddings_and_intermediates
         return_embeddings = return_embeddings | (not exists(self.to_logits)) | return_embeddings_and_intermediates
@@ -3066,7 +3096,7 @@ class TransformerWrapper(Module):

             # regular

-            attended, intermediates = self.attn_layers(x, mask = mask, mems = mems, mem_masks = mem_masks, cache = cache, return_hiddens = True, seq_start_pos = seq_start_pos, **kwargs)
+            attended, intermediates = self.attn_layers(x, mask = mask, mems = mems, mem_masks = mem_masks, cache = cache, deep_embed_ids = token_ids, return_hiddens = True, seq_start_pos = seq_start_pos, **kwargs)

         else:
             # recycling
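At the wrapper level, `TransformerWrapper.forward` now keeps the raw input ids around as `token_ids` and hands them to `AttentionLayers` as `deep_embed_ids`, where they reach each feedforward block and are sliced alongside the cache during speculative decoding. A hedged usage sketch follows; the `ff_deep_embed_hiddens` / `ff_deep_embed_num_tokens` kwarg names are assumed from the library's usual `ff_`-prefix routing into `FeedForward` and are not spelled out in this diff.

```python
# Hedged sketch: kwarg names below are assumptions based on x-transformers'
# ff_-prefix routing, not confirmed by this diff.
import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        ff_deep_embed_hiddens = True,       # assumed name: routed to FeedForward(deep_embed_hiddens = True)
        ff_deep_embed_num_tokens = 20000    # assumed name: routed to FeedForward(deep_embed_num_tokens = 20000)
    )
)

ids = torch.randint(0, 20000, (1, 128))
logits = model(ids)  # per the diff, the input ids are forwarded internally as deep_embed_ids
```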
{x_transformers-2.3.5.dist-info → x_transformers-2.3.7.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: x-transformers
-Version: 2.3.5
+Version: 2.3.7
 Summary: X-Transformers
 Project-URL: Homepage, https://pypi.org/project/x-transformers/
 Project-URL: Repository, https://github.com/lucidrains/x-transformers
@@ -2475,4 +2475,15 @@ ids_out, num_out, is_number_mask = model.generate(start_ids, start_nums, 17)
 }
 ```

+```bibtex
+@misc{Jordan2024,
+    author = {Keller Jordan and Braden Koszarsky},
+    title = {modded-nanogpt (value embeddings from nanogpt speedrun)},
+    year = {2024},
+    publisher = {GitHub},
+    journal = {GitHub repository},
+    howpublished = {https://github.com/KellerJordan/modded-nanogpt},
+}
+```
+
 *solve intelligence... then use that to solve everything else.* - Demis Hassabis
{x_transformers-2.3.5.dist-info → x_transformers-2.3.7.dist-info}/RECORD
CHANGED
@@ -2,16 +2,16 @@ x_transformers/__init__.py,sha256=h3I2ejobgEdy8H7NgV-rP8UaBCnd16-MysvDXH9GMEA,98
 x_transformers/attend.py,sha256=-5BWWhFsp7tvZTdN91Ay5SqOjyj9uOs-122vFvoO6b4,17253
 x_transformers/autoregressive_wrapper.py,sha256=reLCno9Z9pchVU79tBF8OMo21LwSZ67KAeB83jqkyAc,10505
 x_transformers/belief_state_wrapper.py,sha256=YLUMk6t2MhFBEw5lHDDHJHcoCxTIkHvxTNY__GGZEKU,13374
-x_transformers/continuous.py,sha256=
+x_transformers/continuous.py,sha256=DWYD7wwVp0UU5UswK_6CKA_Cmpbl7XfzR9IKMxtECLM,9460
 x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
 x_transformers/entropy_based_tokenizer.py,sha256=F2lO8-v3aLIcVDVNhu7RR-UtRdlmaaYJzBK9m7OnLE8,5018
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/neo_mlp.py,sha256=XCNnnop9WLarcxap1kGuYc1x8GHvwkZiDRnXOxSl3Po,3452
 x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
-x_transformers/x_transformers.py,sha256=
+x_transformers/x_transformers.py,sha256=bwuZlvse3kYpD0EoHM9gWyi5IpXkF-jsNgQqJGjfRzs,113501
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=7S00kCuab4tWQa-vf-z-XfzADjVj48MoFIr7VSIvttg,8575
-x_transformers-2.3.
-x_transformers-2.3.
-x_transformers-2.3.
-x_transformers-2.3.
+x_transformers-2.3.7.dist-info/METADATA,sha256=U80x0At1b-5MP3co5wM7CdL7zmtTHWjbYzXn5ypEBoU,89021
+x_transformers-2.3.7.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+x_transformers-2.3.7.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-2.3.7.dist-info/RECORD,,
{x_transformers-2.3.5.dist-info → x_transformers-2.3.7.dist-info}/WHEEL
File without changes

{x_transformers-2.3.5.dist-info → x_transformers-2.3.7.dist-info}/licenses/LICENSE
File without changes