x-transformers 2.3.20__py3-none-any.whl → 2.3.22__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
x_transformers/autoregressive_wrapper.py

@@ -140,7 +140,8 @@ class AutoregressiveWrapper(Module):
         ignore_index = -100,
         pad_value = 0,
         mask_prob = 0.,
-        add_attn_z_loss = False
+        add_attn_z_loss = False,
+        next_embed_loss_weight = 0.1
     ):
         super().__init__()
         self.pad_value = pad_value
@@ -156,6 +157,10 @@ class AutoregressiveWrapper(Module):
         # whether to add router z-loss
         self.add_attn_z_loss = add_attn_z_loss
 
+        # whether to add a continuous loss
+        self.add_continuous_pred_head = net.add_continuous_pred_head
+        self.next_embed_loss_weight = next_embed_loss_weight
+
     @torch.no_grad()
     @eval_decorator
     def generate(
@@ -305,9 +310,9 @@ class AutoregressiveWrapper(Module):
         return out
 
     def forward(self, x, return_outputs = False, **kwargs):
-        seq, ignore_index, add_attn_z_loss = x.shape[1], self.ignore_index, self.add_attn_z_loss
+        seq, ignore_index, add_attn_z_loss, add_next_embed_loss = x.shape[1], self.ignore_index, self.add_attn_z_loss, self.add_continuous_pred_head
 
-        inp, target = x[:, :-1], x[:, 1:]
+        inp, target = x, x[:, 1:]
         inp = torch.where(inp == ignore_index, self.pad_value, inp)
 
         if self.mask_prob > 0.:
@@ -318,15 +323,29 @@ class AutoregressiveWrapper(Module):
             mask = ~torch.zeros_like(inp).scatter(1, indices, 1.).bool()
             kwargs.update(self_attn_kv_mask = mask)
 
-        logits, cache = self.net(
-            inp,
+        out, cache = self.net(
+            x,
             return_intermediates = True,
             return_attn_z_loss = add_attn_z_loss,
+            return_next_embed_pred = add_next_embed_loss,
             **kwargs
         )
 
+        # destruct differently if doing continuous pred
+
+        if add_next_embed_loss:
+            logits, (next_embed_pred, init_embeds) = out
+        else:
+            logits = out
+
+        logits = logits[:, :-1]
+
+        # loss function
+
         loss_fn = F.cross_entropy if not self.net.output_is_log_prob else F.nll_loss
 
+        # cross entropy loss
+
         loss = loss_fn(
             rearrange(logits, 'b n c -> b c n'),
             target,
@@ -336,6 +355,16 @@ class AutoregressiveWrapper(Module):
         if add_attn_z_loss:
             loss = loss + cache.attn_z_loss
 
+        if add_next_embed_loss:
+            mask = inp[:, :-1] != ignore_index
+            embed_pred = next_embed_pred[:, :-1]
+            cont_targets = init_embeds[:, 1:].detach()
+
+            cont_loss = F.l1_loss(embed_pred, cont_targets, reduction = 'none')
+            cont_loss = cont_loss[mask].mean()
+
+            loss = loss + cont_loss * self.next_embed_loss_weight
+
         if not return_outputs:
             return loss
 
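Taken together, these hunks give AutoregressiveWrapper an auxiliary continuous loss on top of the usual cross entropy: when the wrapped network exposes a continuous prediction head, the embedding predicted at position t is regressed (L1) against the detached input embedding at position t + 1, padding positions are masked out, and the result is added to the loss scaled by next_embed_loss_weight. A minimal standalone sketch of just that auxiliary term, using random stand-in tensors rather than the library itself:

    import torch
    import torch.nn.functional as F

    ignore_index, next_embed_loss_weight = -100, 0.1
    batch, seq, dim = 2, 8, 16

    # stand-ins for what the network would produce
    inp = torch.randint(0, 10, (batch, seq))          # token ids (may contain ignore_index)
    next_embed_pred = torch.randn(batch, seq, dim)    # head output: predicted next-step embeddings
    init_embeds = torch.randn(batch, seq, dim)        # token embeddings before the attention layers

    # only positions that have a real next-step target contribute
    mask = inp[:, :-1] != ignore_index

    # prediction at position t is matched against the (detached) embedding at t + 1
    embed_pred = next_embed_pred[:, :-1]
    cont_targets = init_embeds[:, 1:].detach()

    cont_loss = F.l1_loss(embed_pred, cont_targets, reduction = 'none')
    cont_loss = cont_loss[mask].mean()

    # the wrapper then adds this as: loss + cont_loss * next_embed_loss_weight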
x_transformers/continuous.py

@@ -141,6 +141,8 @@ class ContinuousTransformerWrapper(Module):
         sum_embeds = None,
         prepend_embeds = None,
         prepend_mask = None,
+        cache: LayerIntermediates | None = None,
+        input_not_include_cache = False,
         seq_start_pos = None,
         **kwargs
     ):
@@ -154,10 +156,17 @@ class ContinuousTransformerWrapper(Module):
 
             mask = einx.less('j, i -> i j', seq_arange, lens)
 
+        # take care of position embedding offsets in the presence of cache and sequence is less than cache length (not full sequence)
+
+        seq_pos_offset = 0
+
+        if exists(cache) and input_not_include_cache:
+            seq_pos_offset = cache.cache_length
+
         # project in + positional embedding
 
         x = self.project_in(x)
-        x = x + self.pos_emb(x, pos = pos, seq_start_pos = seq_start_pos)
+        x = x + self.pos_emb(x, pos = pos, seq_start_pos = seq_start_pos, offset = seq_pos_offset)
 
         if exists(sum_embeds):
             x = x + sum_embeds
@@ -193,7 +202,7 @@ class ContinuousTransformerWrapper(Module):
 
         # attention layers
 
-        x, intermediates = self.attn_layers(x, mask = mask, mems = mems, mem_masks = mem_masks, return_hiddens = True, **kwargs)
+        x, intermediates = self.attn_layers(x, mask = mask, mems = mems, mem_masks = mem_masks, cache = cache, input_not_include_cache = input_not_include_cache, seq_pos_offset = seq_pos_offset, return_hiddens = True, **kwargs)
 
         # splice out memory tokens
 
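With these changes ContinuousTransformerWrapper can be stepped incrementally: a cache of LayerIntermediates can be fed back in, and when input_not_include_cache is set the caller passes only the new timesteps, so the positional embedding must be offset by the number of steps the cache already covers. A toy sketch of that position bookkeeping (cache_length and new_steps are made-up numbers; this does not call the library):

    import torch

    cache_length = 6                      # steps already held in the cache (hypothetical)
    new_steps = 2                         # new frames passed to forward on this call

    seq_pos_offset = cache_length         # mirrors seq_pos_offset = cache.cache_length
    pos = torch.arange(new_steps) + seq_pos_offset

    print(pos)                            # tensor([6, 7]) -- the new frames land at absolute positions 6 and 7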
x_transformers/x_transformers.py

@@ -2855,6 +2855,7 @@ class TransformerWrapper(Module):
         sigsoftmax_logits = False,
         ff_deep_embed = False,
         to_logits: Module | None = None,
+        add_continuous_pred_head = False
     ):
         super().__init__()
 
@@ -2975,6 +2976,18 @@ class TransformerWrapper(Module):
         else:
             self.to_logits = LinearNoBias(dim, logits_dim) if not exists(to_logits) else to_logits
 
+        # add a head that predicts the embedding of the next step
+
+        self.add_continuous_pred_head = add_continuous_pred_head
+
+        if add_continuous_pred_head:
+
+            self.to_next_embed_pred = nn.Sequential(
+                LinearNoBias(dim, dim),
+                nn.SiLU(),
+                LinearNoBias(dim, dim)
+            )
+
         # memory tokens (like [cls]) from Memory Transformers paper
 
         num_memory_tokens = default(num_memory_tokens, 0)
@@ -3009,6 +3022,7 @@ class TransformerWrapper(Module):
         return_intermediates = False,
         return_embeddings_and_intermediates = False,
         return_logit_entropies = False,
+        return_next_embed_pred = False,
         mask = None,
         return_mems = False,
         return_attn = False,
@@ -3100,6 +3114,10 @@ class TransformerWrapper(Module):
             assert emb_frac_gradient > 0
             x = x * emb_frac_gradient + x.detach() * (1 - emb_frac_gradient)
 
+        # init embed
+
+        init_embed = x
+
         # embedding dropout
 
         x = self.emb_dropout(x)
@@ -3261,6 +3279,14 @@ class TransformerWrapper(Module):
         else:
             out = logits
 
+        # maybe next embed pred
+
+        if return_next_embed_pred:
+            assert self.add_continuous_pred_head
+            next_embed_out = self.to_next_embed_pred(x)
+
+            out = (out, (next_embed_out, init_embed))
+
         # logit entropies
 
         if return_logit_entropies:
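End to end, the new keyword arguments combine as follows. A hedged usage sketch based only on the arguments visible in this diff (model sizes and token counts are arbitrary):

    import torch
    from x_transformers import TransformerWrapper, Decoder
    from x_transformers.autoregressive_wrapper import AutoregressiveWrapper

    model = TransformerWrapper(
        num_tokens = 256,
        max_seq_len = 1024,
        add_continuous_pred_head = True,              # new in this release
        attn_layers = Decoder(dim = 512, depth = 6, heads = 8)
    )

    wrapper = AutoregressiveWrapper(model, next_embed_loss_weight = 0.1)

    seq = torch.randint(0, 256, (1, 1024))
    loss = wrapper(seq)                               # cross entropy + weighted next-embed L1 term
    loss.backward()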
x_transformers-2.3.20.dist-info/METADATA → x_transformers-2.3.22.dist-info/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: x-transformers
-Version: 2.3.20
+Version: 2.3.22
 Summary: X-Transformers
 Project-URL: Homepage, https://pypi.org/project/x-transformers/
 Project-URL: Repository, https://github.com/lucidrains/x-transformers
x_transformers-2.3.20.dist-info/RECORD → x_transformers-2.3.22.dist-info/RECORD

@@ -1,17 +1,17 @@
 x_transformers/__init__.py,sha256=h3I2ejobgEdy8H7NgV-rP8UaBCnd16-MysvDXH9GMEA,985
 x_transformers/attend.py,sha256=fXMuwHuBAFB4f4_U6j5_uVeK7N4cV0PDd6UTqtkjKKM,17333
-x_transformers/autoregressive_wrapper.py,sha256=LW1gr3cFONDEPA_HHhaTE7mk-JWbaINuB1fc_DfbCqw,10791
+x_transformers/autoregressive_wrapper.py,sha256=BWFaO-3YWzCcEfp-EC1ZkdckqDpPIOQG6_uyyP6AmhM,11753
 x_transformers/belief_state_wrapper.py,sha256=YLUMk6t2MhFBEw5lHDDHJHcoCxTIkHvxTNY__GGZEKU,13374
-x_transformers/continuous.py,sha256=CHta8vizKl85n220fv5278fwjSU-vrN_FBy-m831_go,12551
+x_transformers/continuous.py,sha256=hpb1sSbt3k2LNzzjrjSd8F5xOIbKj7IluV9MBEAFLkw,13031
 x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
 x_transformers/entropy_based_tokenizer.py,sha256=F2lO8-v3aLIcVDVNhu7RR-UtRdlmaaYJzBK9m7OnLE8,5018
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/neo_mlp.py,sha256=XCNnnop9WLarcxap1kGuYc1x8GHvwkZiDRnXOxSl3Po,3452
 x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
-x_transformers/x_transformers.py,sha256=l2p-r0iJNlYHUB3vM4lb6ptzNCx9HgA7UfgieEcQT6w,115521
+x_transformers/x_transformers.py,sha256=7phSZvP1_SDRIkVMwVR4cz1dFU2UlR2Wf1HJHEQlcQg,116222
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=AwwYUm8yDAtKQyKJDIhYMsiLTJ_skh3scUFMjp5sda8,8597
-x_transformers-2.3.20.dist-info/METADATA,sha256=ygWyfnlIh2Mw6bd12gJjjZJyM9vfnXmvvOLyrd2El2k,89897
-x_transformers-2.3.20.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-x_transformers-2.3.20.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-2.3.20.dist-info/RECORD,,
+x_transformers-2.3.22.dist-info/METADATA,sha256=_8m8ftpHRKjbEUDuoeYPcVh4yan1FxNRj3seJwiZzl8,89897
+x_transformers-2.3.22.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+x_transformers-2.3.22.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-2.3.22.dist-info/RECORD,,