x-transformers 2.3.5-py3-none-any.whl → 2.3.6-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -220,8 +220,6 @@ class ContinuousAutoregressiveWrapper(Module):
     def __init__(
         self,
         net: ContinuousTransformerWrapper,
-        ignore_index = -100,
-        pad_value = 0,
         loss_fn: Module | None = None,
         equal_loss_weight_batch = False # setting this to True, if the mask is passed in and sequences are variable in length, each sequence will be weighted the same (as opposed to each token)
     ):
@@ -1207,7 +1207,9 @@ class FeedForward(Module):
         dropout = 0.,
         sublayer_dropout = 0.,
         no_bias = False,
-        zero_init_output = False
+        zero_init_output = False,
+        deep_embed_hiddens = False,
+        deep_embed_num_tokens = None,
     ):
         super().__init__()
         inner_dim = int(dim * mult)
@@ -1223,27 +1225,51 @@ class FeedForward(Module):
         activation = nn.GELU()
 
         if glu:
-            project_in = GLU(dim, inner_dim, activation, mult_bias = glu_mult_bias)
+            proj_in = GLU(dim, inner_dim, activation, mult_bias = glu_mult_bias)
         else:
-            project_in = nn.Sequential(
+            proj_in = nn.Sequential(
                 nn.Linear(dim, inner_dim, bias = not no_bias),
                 activation
             )
 
+        proj_out = nn.Linear(inner_dim, dim_out, bias = not no_bias)
+
         self.ff = Sequential(
-            project_in,
+            proj_in,
             LayerNorm(inner_dim) if post_act_ln else None,
             nn.Dropout(dropout),
-            nn.Linear(inner_dim, dim_out, bias = not no_bias),
+            proj_out,
             nn.Dropout(sublayer_dropout) if sublayer_dropout > 0. else None
         )
 
+        # deep embed
+
+        # credit goes to Braden Koszarsky for first devising value embeddings in nanogpt-speedrun project
+        # then Bo Peng for coming up with this alternate design in feedforward for RWKV 8
+        # improvements were clearest to me (on my toy setup) with multiplying on output of feedforward, will try with attention at future date
+
+        self.deep_embed = None
+        if deep_embed_hiddens:
+            assert exists(deep_embed_num_tokens)
+            self.deep_embed = nn.Parameter(torch.zeros(deep_embed_num_tokens, dim_out))
+
         # init last linear layer to 0
+
         if zero_init_output:
-            init_zero_(self.ff[-1])
+            init_zero_(proj_out)
 
-    def forward(self, x):
-        return self.ff(x)
+    def forward(
+        self,
+        x,
+        deep_embed_ids = None
+    ):
+        out = self.ff(x)
+
+        if exists(deep_embed_ids) and exists(self.deep_embed):
+            deep_embed = self.deep_embed[deep_embed_ids] + 1.
+            out = out * deep_embed
+
+        return out
 
 # attention. it is all we need
 
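What this hunk adds, in isolation: `FeedForward` gains an optional per-token parameter table (`deep_embed`), initialized to zeros, looked up by token id in `forward`, shifted by `+ 1.` so the gate starts as an identity, and multiplied onto the feedforward output. A minimal standalone sketch of that idea follows; the module name and dimensions are illustrative, not the library's code.

```python
import torch
from torch import nn

# minimal sketch of the deep-embed gate added to FeedForward above
# (illustrative module, not the x-transformers implementation)
class DeepEmbedFF(nn.Module):
    def __init__(self, dim, num_tokens, mult = 4):
        super().__init__()
        self.ff = nn.Sequential(
            nn.Linear(dim, dim * mult),
            nn.GELU(),
            nn.Linear(dim * mult, dim)
        )
        # zeros at init -> gate of exactly 1. -> behaves like a plain feedforward until trained
        self.deep_embed = nn.Parameter(torch.zeros(num_tokens, dim))

    def forward(self, x, token_ids = None):
        out = self.ff(x)
        if token_ids is not None:
            out = out * (self.deep_embed[token_ids] + 1.)  # per-token multiplicative gate on the ffn output
        return out

ff = DeepEmbedFF(dim = 64, num_tokens = 256)
x = torch.randn(2, 10, 64)
ids = torch.randint(0, 256, (2, 10))
assert torch.allclose(ff(x), ff(x, ids))  # identity at initialization, since the table is all zeros
```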
@@ -2354,6 +2380,7 @@ class AttentionLayers(Module):
         pos = None,
         context_pos = None,
         attn_bias = None,
+        deep_embed_ids = None,
         condition = None,
         in_attn_cond = None, # https://arxiv.org/abs/2105.04090
         layers_execute_order: tuple[int, ...] | None = None
@@ -2448,6 +2475,9 @@ class AttentionLayers(Module):
             if cache_age > 0:
                 x = x[:, -cache_age:] # for spec decoding, may be greater than 1
 
+            if exists(deep_embed_ids):
+                deep_embed_ids = deep_embed_ids[:, -cache_age:]
+
             attn_cache = cache.attn_intermediates
 
         iter_attn_cache = iter(attn_cache)
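Context for the hunk above: when a decoding cache is present, `x` has already been trimmed to the newest `cache_age` positions, so `deep_embed_ids` must be trimmed identically or the per-token gate lookups in the feedforward layers would drift out of alignment. A toy illustration (all shapes made up):

```python
import torch

# toy illustration of the cache-age slicing above (shapes are made up)
token_ids = torch.arange(12).unsqueeze(0)   # (1, 12) full sequence of token ids
x = torch.randn(1, 12, 8)                   # hidden states for the same 12 positions

cache_age = 3                               # three positions newer than the cache
x = x[:, -cache_age:]                       # (1, 3, 8)
token_ids = token_ids[:, -cache_age:]       # (1, 3) -- stays aligned with the trimmed x

assert x.shape[1] == token_ids.shape[1]
```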
@@ -2572,7 +2602,7 @@ class AttentionLayers(Module):
             elif layer_type == 'c':
                 out, inter = block(x, context = context, mask = mask, context_mask = context_mask, prev_attn = prev_cross_attn, cache = next(iter_attn_cache, None), value_residual = maybe_cross_attn_value_residual, **cross_attn_rotary_pos_emb, return_intermediates = True)
             elif layer_type == 'f':
-                out = block(x)
+                out = block(x, deep_embed_ids = deep_embed_ids)
 
             # store first self or cross attention intermediate for value residual
 
@@ -2959,7 +2989,7 @@ class TransformerWrapper(Module):
 
         # shapes and variables
 
-        b, n, device, num_mems, has_memory_tokens, emb_frac_gradient, orig_mask = x.shape[0], x.shape[1], x.device, self.num_memory_tokens, self.num_memory_tokens > 0, self.emb_frac_gradient, mask
+        b, n, device, token_ids, num_mems, has_memory_tokens, emb_frac_gradient, orig_mask = x.shape[0], x.shape[1], x.device, x, self.num_memory_tokens, self.num_memory_tokens > 0, self.emb_frac_gradient, mask
 
         return_hiddens = return_mems | return_attn | return_intermediates | return_attn_z_loss | return_embeddings_and_intermediates
         return_embeddings = return_embeddings | (not exists(self.to_logits)) | return_embeddings_and_intermediates
@@ -3066,7 +3096,7 @@ class TransformerWrapper(Module):
 
             # regular
 
-            attended, intermediates = self.attn_layers(x, mask = mask, mems = mems, mem_masks = mem_masks, cache = cache, return_hiddens = True, seq_start_pos = seq_start_pos, **kwargs)
+            attended, intermediates = self.attn_layers(x, mask = mask, mems = mems, mem_masks = mem_masks, cache = cache, deep_embed_ids = token_ids, return_hiddens = True, seq_start_pos = seq_start_pos, **kwargs)
 
         else:
             # recycling
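Taken together, the two `TransformerWrapper` hunks keep a handle on the raw token ids (`token_ids = x`, captured before `x` is replaced by embeddings) and forward them to `attn_layers` as `deep_embed_ids` on every call. Enabling the gate itself presumably goes through the `ff_`-prefixed kwargs that x-transformers routes into `FeedForward`; the kwarg names in the sketch below (`ff_deep_embed_hiddens`, `ff_deep_embed_num_tokens`) are an assumption based on that convention and are not confirmed by this diff.

```python
import torch
from x_transformers import TransformerWrapper, Decoder

# hedged usage sketch -- the ff_* kwarg names are assumed, based on the usual
# ff_-prefix routing into FeedForward, and are not confirmed by this diff
model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        ff_deep_embed_hiddens = True,      # assumed: maps to FeedForward(deep_embed_hiddens = True)
        ff_deep_embed_num_tokens = 20000   # assumed: sizes the per-token gate table
    )
)

ids = torch.randint(0, 20000, (1, 128))
logits = model(ids)  # token ids also reach the feedforward deep-embed gates internally
```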
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: x-transformers
-Version: 2.3.5
+Version: 2.3.6
 Summary: X-Transformers
 Project-URL: Homepage, https://pypi.org/project/x-transformers/
 Project-URL: Repository, https://github.com/lucidrains/x-transformers
@@ -2475,4 +2475,15 @@ ids_out, num_out, is_number_mask = model.generate(start_ids, start_nums, 17)
 }
 ```
 
+```bibtex
+@misc{Jordan2024,
+    author = {Keller Jordan and Braden Koszarsky},
+    title = {modded-nanogpt (value embeddings from nanogpt speedrun)},
+    year = {2024},
+    publisher = {GitHub},
+    journal = {GitHub repository},
+    howpublished = {https://github.com/KellerJordan/modded-nanogpt},
+}
+```
+
 *solve intelligence... then use that to solve everything else.* - Demis Hassabis
@@ -2,16 +2,16 @@ x_transformers/__init__.py,sha256=h3I2ejobgEdy8H7NgV-rP8UaBCnd16-MysvDXH9GMEA,98
 x_transformers/attend.py,sha256=-5BWWhFsp7tvZTdN91Ay5SqOjyj9uOs-122vFvoO6b4,17253
 x_transformers/autoregressive_wrapper.py,sha256=reLCno9Z9pchVU79tBF8OMo21LwSZ67KAeB83jqkyAc,10505
 x_transformers/belief_state_wrapper.py,sha256=YLUMk6t2MhFBEw5lHDDHJHcoCxTIkHvxTNY__GGZEKU,13374
-x_transformers/continuous.py,sha256=bTxwCt_8RlT1-aR2F4R8YOhpjMF-TbpElRbbRiNd6M8,9512
+x_transformers/continuous.py,sha256=DWYD7wwVp0UU5UswK_6CKA_Cmpbl7XfzR9IKMxtECLM,9460
 x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
 x_transformers/entropy_based_tokenizer.py,sha256=F2lO8-v3aLIcVDVNhu7RR-UtRdlmaaYJzBK9m7OnLE8,5018
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/neo_mlp.py,sha256=XCNnnop9WLarcxap1kGuYc1x8GHvwkZiDRnXOxSl3Po,3452
 x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
-x_transformers/x_transformers.py,sha256=MF91aJGr2DOjIGe57uqwgyNxCExBg_tI9z7usAJMxOM,112401
+x_transformers/x_transformers.py,sha256=kZKk80hxV0Pvmx1E745BR7c8YzB-S4u2cZHSMZvpZq8,113507
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=7S00kCuab4tWQa-vf-z-XfzADjVj48MoFIr7VSIvttg,8575
-x_transformers-2.3.5.dist-info/METADATA,sha256=wPHqpSgc75F3npfdSNCzro1F6PBlVXabA0oarpvZMHI,88686
-x_transformers-2.3.5.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-x_transformers-2.3.5.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-2.3.5.dist-info/RECORD,,
+x_transformers-2.3.6.dist-info/METADATA,sha256=Z337g7NRRYaKGbBHkKe1UZbIQJeXPk-dtZ4aBiVvSH8,89021
+x_transformers-2.3.6.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+x_transformers-2.3.6.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-2.3.6.dist-info/RECORD,,