x-transformers 2.3.7__py3-none-any.whl → 2.3.9__py3-none-any.whl

This diff shows the changes between publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the packages as they appear in their respective public registries.
@@ -1208,8 +1208,6 @@ class FeedForward(Module):
         sublayer_dropout = 0.,
         no_bias = False,
         zero_init_output = False,
-        deep_embed_hiddens = False,
-        deep_embed_num_tokens = None,
     ):
         super().__init__()
         inner_dim = int(dim * mult)
@@ -1242,17 +1240,6 @@ class FeedForward(Module):
             nn.Dropout(sublayer_dropout) if sublayer_dropout > 0. else None
         )
 
-        # deep embed
-
-        # credit goes to Braden Koszarsky for first devising value embeddings in nanogpt-speedrun project
-        # then Bo Peng for coming up with this alternate design in feedforward for RWKV 8
-        # improvements were clearest to me (on my toy setup) with multiplying on output of feedforward, will try with attention at future date
-
-        self.deep_embed = None
-        if deep_embed_hiddens:
-            assert exists(deep_embed_num_tokens)
-            self.deep_embed = nn.Parameter(torch.ones(deep_embed_num_tokens, dim_out))
-
         # init last linear layer to 0
 
         if zero_init_output:
@@ -1261,12 +1248,11 @@ class FeedForward(Module):
     def forward(
         self,
         x,
-        deep_embed_ids = None
+        deep_embed = None
    ):
        out = self.ff(x)
 
-        if exists(deep_embed_ids) and exists(self.deep_embed):
-            deep_embed = self.deep_embed[deep_embed_ids]
+        if exists(deep_embed):
            out = out * deep_embed
 
        return out
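After this change, `FeedForward` no longer owns a per-token embedding table; it just multiplies its output by whatever `deep_embed` tensor the caller supplies. A minimal standalone sketch of that behavior follows (the `SimpleFeedForward` class and its plain Linear-GELU-Linear stack are illustrative stand-ins, not the library's actual module):

```python
import torch
from torch import nn

class SimpleFeedForward(nn.Module):
    # illustrative stand-in for x_transformers' FeedForward after 2.3.9
    def __init__(self, dim, mult = 4):
        super().__init__()
        inner_dim = int(dim * mult)
        self.ff = nn.Sequential(
            nn.Linear(dim, inner_dim),
            nn.GELU(),
            nn.Linear(inner_dim, dim)
        )

    def forward(self, x, deep_embed = None):
        out = self.ff(x)

        # multiplicative gate on the feedforward output, supplied by the caller
        if deep_embed is not None:
            out = out * deep_embed

        return out

ff = SimpleFeedForward(dim = 64)
x = torch.randn(2, 16, 64)        # (batch, seq, dim)
gate = torch.ones(2, 16, 64)      # per-token gate, e.g. rows gathered from a learned table
assert ff(x, deep_embed = gate).shape == x.shape
```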
@@ -2380,7 +2366,7 @@ class AttentionLayers(Module):
         pos = None,
         context_pos = None,
         attn_bias = None,
-        deep_embed_ids = None,
+        deep_embeds_and_ids: tuple[nn.Parameter, Tensor] | None = None,
         condition = None,
         in_attn_cond = None, # https://arxiv.org/abs/2105.04090
         layers_execute_order: tuple[int, ...] | None = None
@@ -2475,13 +2461,26 @@ class AttentionLayers(Module):
            if cache_age > 0:
                x = x[:, -cache_age:] # for spec decoding, may be greater than 1
 
-                if exists(deep_embed_ids):
-                    deep_embed_ids = deep_embed_ids[:, -cache_age:]
+                if exists(deep_embeds_and_ids):
+                    deep_embeds, token_ids = deep_embeds_and_ids
+                    token_ids = token_ids[:, -cache_age:]
+                    deep_embeds_and_ids = (deep_embeds, token_ids)
 
            attn_cache = cache.attn_intermediates
 
        iter_attn_cache = iter(attn_cache)
 
+        # handle deep embeds if needed
+
+        deep_embeds = []
+
+        if exists(deep_embeds_and_ids):
+            deep_embeds, token_ids = deep_embeds_and_ids
+            deep_embeds_across_depth = deep_embeds[token_ids]
+            deep_embeds = rearrange(deep_embeds_across_depth, 'b n l d -> l b n d')
+
+        deep_embeds_iter = iter(deep_embeds)
+
        # setup multistreams if needed
 
        streams = self.num_residual_streams
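The hunk above gathers one gating vector per (token, layer): indexing a `(num_tokens, depth, dim)` table with token ids of shape `(batch, seq)` yields `(batch, seq, depth, dim)`, which is rearranged to put depth first so each feedforward layer pulls its own slice from an iterator. A standalone sketch of that gather-and-rearrange logic, with illustrative shapes:

```python
import torch
from einops import rearrange

num_tokens, depth, dim = 1000, 6, 64
batch, seq = 2, 16

# learned table: one gating vector per (token, layer), initialized to ones in the diff
deep_embeds = torch.ones(num_tokens, depth, dim)
token_ids = torch.randint(0, num_tokens, (batch, seq))

# gather per-token rows -> (batch, seq, depth, dim), then move depth to the front
per_token = deep_embeds[token_ids]                       # (b, n, l, d)
per_layer = rearrange(per_token, 'b n l d -> l b n d')   # (l, b, n, d)

deep_embeds_iter = iter(per_layer)
layer_gate = next(deep_embeds_iter)                      # (b, n, d) gate for the first feedforward
assert layer_gate.shape == (batch, seq, dim)
```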
@@ -2602,7 +2601,7 @@ class AttentionLayers(Module):
            elif layer_type == 'c':
                out, inter = block(x, context = context, mask = mask, context_mask = context_mask, prev_attn = prev_cross_attn, cache = next(iter_attn_cache, None), value_residual = maybe_cross_attn_value_residual, **cross_attn_rotary_pos_emb, return_intermediates = True)
            elif layer_type == 'f':
-                out = block(x, deep_embed_ids = deep_embed_ids)
+                out = block(x, deep_embed = next(deep_embeds_iter, None))
 
            # store first self or cross attention intermediate for value residual
 
@@ -2818,11 +2817,14 @@ class TransformerWrapper(Module):
        mixture_of_softmax = False,
        mixture_of_softmax_k = 4,
        sigsoftmax_logits = False,
+        ff_deep_embed = False,
        to_logits: Module | None = None,
    ):
        super().__init__()
 
        dim = attn_layers.dim
+        depth = attn_layers.depth
+
        emb_dim = default(emb_dim, dim)
        self.emb_dim = emb_dim
        self.num_tokens = num_tokens
@@ -2855,6 +2857,16 @@ class TransformerWrapper(Module):
        if len(embed_num_tokens) > 0:
            self.embeds = ModuleDict({f'{name}_embed': nn.Embedding(num_tokens, emb_dim) for name, num_tokens in embed_num_tokens.items()})
 
+        # deep embed
+
+        # credit goes to Braden Koszarsky for first devising value embeddings in nanogpt-speedrun project
+        # then Bo Peng for coming up with this alternate design in feedforward for RWKV 8
+        # improvements were clearest to me (on my toy setup) with multiplying on output of feedforward, will try with attention at future date
+
+        self.ff_deep_embed = None
+        if ff_deep_embed:
+            self.ff_deep_embed = nn.Parameter(torch.ones(num_tokens, depth, dim))
+
        # fraction of the gradient that should go to the embedding, https://arxiv.org/abs/2105.13290
 
        self.emb_frac_gradient = emb_frac_gradient
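The table therefore moves from one `(num_tokens, dim)` parameter per feedforward (in 2.3.7) to a single `(num_tokens, depth, dim)` parameter on the wrapper, adding num_tokens × depth × dim parameters when the flag is on. A quick sketch of what gets registered (the numbers below are illustrative, not defaults):

```python
import torch
from torch import nn

num_tokens, depth, dim = 256, 8, 512

# one gating row per (token, layer); initialized to ones so training starts from an identity gate
ff_deep_embed = nn.Parameter(torch.ones(num_tokens, depth, dim))

print(ff_deep_embed.numel())  # 256 * 8 * 512 = 1_048_576 extra parameters
```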
@@ -3050,6 +3062,13 @@ class TransformerWrapper(Module):
 
        x = self.project_emb(x)
 
+        # maybe deep embeds
+
+        deep_embed_and_ids = None
+
+        if exists(self.ff_deep_embed):
+            deep_embed_and_ids = (self.ff_deep_embed, token_ids)
+
        # maybe cls token
 
        if exists(self.cls_token):
@@ -3096,7 +3115,7 @@ class TransformerWrapper(Module):
 
            # regular
 
-            attended, intermediates = self.attn_layers(x, mask = mask, mems = mems, mem_masks = mem_masks, cache = cache, deep_embed_ids = token_ids, return_hiddens = True, seq_start_pos = seq_start_pos, **kwargs)
+            attended, intermediates = self.attn_layers(x, mask = mask, mems = mems, mem_masks = mem_masks, cache = cache, deep_embeds_and_ids = deep_embed_and_ids, return_hiddens = True, seq_start_pos = seq_start_pos, **kwargs)
 
        else:
            # recycling
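Putting the pieces together, the feature is switched on from the wrapper, which looks up the gates from the input token ids internally. A hedged usage sketch against the public x-transformers API (dimensions are arbitrary; `ff_deep_embed` is the kwarg introduced in this diff):

```python
import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 1024,
    ff_deep_embed = True,      # new in 2.3.9: per-token, per-layer gates on feedforward outputs
    attn_layers = Decoder(
        dim = 512,
        depth = 8,
        heads = 8
    )
)

token_ids = torch.randint(0, 256, (1, 1024))
logits = model(token_ids)      # gates are gathered from token_ids inside the wrapper
print(logits.shape)            # (1, 1024, 256)
```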
--- x_transformers-2.3.7.dist-info/METADATA
+++ x_transformers-2.3.9.dist-info/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: x-transformers
-Version: 2.3.7
+Version: 2.3.9
 Summary: X-Transformers
 Project-URL: Homepage, https://pypi.org/project/x-transformers/
 Project-URL: Repository, https://github.com/lucidrains/x-transformers
--- x_transformers-2.3.7.dist-info/RECORD
+++ x_transformers-2.3.9.dist-info/RECORD
@@ -8,10 +8,10 @@ x_transformers/entropy_based_tokenizer.py,sha256=F2lO8-v3aLIcVDVNhu7RR-UtRdlmaaY
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/neo_mlp.py,sha256=XCNnnop9WLarcxap1kGuYc1x8GHvwkZiDRnXOxSl3Po,3452
 x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
-x_transformers/x_transformers.py,sha256=bwuZlvse3kYpD0EoHM9gWyi5IpXkF-jsNgQqJGjfRzs,113501
+x_transformers/x_transformers.py,sha256=Wo5hauzdn4Q9PUVjBqQo-1vCq08BT2jYUDbq3r2a5Go,114061
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=7S00kCuab4tWQa-vf-z-XfzADjVj48MoFIr7VSIvttg,8575
-x_transformers-2.3.7.dist-info/METADATA,sha256=U80x0At1b-5MP3co5wM7CdL7zmtTHWjbYzXn5ypEBoU,89021
-x_transformers-2.3.7.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-x_transformers-2.3.7.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-2.3.7.dist-info/RECORD,,
+x_transformers-2.3.9.dist-info/METADATA,sha256=M0rUg95el9swG5VFWas-5YjahLSgytT65W9L3Ne7BJM,89021
+x_transformers-2.3.9.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+x_transformers-2.3.9.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-2.3.9.dist-info/RECORD,,