x_transformers-2.1.15-py3-none-any.whl → x_transformers-2.1.16-py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
--- a/x_transformers/belief_state_wrapper.py
+++ b/x_transformers/belief_state_wrapper.py
@@ -24,6 +24,7 @@ from x_transformers.x_transformers import (
 
 import einx
 from einops import rearrange, repeat
+from einops.layers.torch import Rearrange
 
 # helper functions
 
@@ -55,7 +56,9 @@ class BeliefStateWrapper(Module):
         backward_decoder: TransformerWrapper | None = None,
         train_frac_forward_backward_pairs: float = 1.,
         text_head: Module | None = None,
-        backward_ar_loss_weight: float = 1. # can weigh the training of the backwards decoder differently, perhaps fwd/bwd have a shared backbone etc etc
+        backward_ar_loss_weight: float = 1., # can weigh the training of the backwards decoder differently, perhaps fwd/bwd have a shared backbone etc etc
+        pred_terminal = False,
+        pred_terminal_loss_weight: float = 1.
     ):
         super().__init__()
         backward_decoder = default(backward_decoder, forward_decoder) # if backward decoder not set, use the same transformer, assume it knows how to switch gears based on suffix token
@@ -82,6 +85,17 @@ class BeliefStateWrapper(Module):
 
         self.text_head = text_head
 
+        # predicting terminal state (when suffix and prefix predict the same token)
+
+        self.to_terminal_logit = nn.Sequential(
+            nn.Linear(dim * 2, dim),
+            nn.LeakyReLU(),
+            nn.Linear(dim, 1),
+            Rearrange('... 1 -> ...')
+        ) if pred_terminal else None
+
+        self.pred_terminal_loss_weight = pred_terminal_loss_weight
+
         # the two decoders, one which is causal forward, the other causal backwards
 
         self.forward_decoder = forward_decoder
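For context, the new `to_terminal_logit` head takes the concatenated forward and backward embeddings (hence the `dim * 2` input width) and produces one logit per forward/backward pair, with the trailing `Rearrange('... 1 -> ...')` squeezing out the last singleton dimension. A minimal shape sketch, assuming `dim = 512` and arbitrary batch/pair counts (not taken from this diff):

```python
import torch
from torch import nn
from einops.layers.torch import Rearrange

dim = 512  # assumed model dimension, for illustration only

to_terminal_logit = nn.Sequential(
    nn.Linear(dim * 2, dim),
    nn.LeakyReLU(),
    nn.Linear(dim, 1),
    Rearrange('... 1 -> ...')  # (..., 1) -> (...), dropping the singleton logit dim
)

fb_embeds = torch.randn(2, 100, dim * 2)   # (batch, forward/backward pairs, concatenated embeds)
print(to_terminal_logit(fb_embeds).shape)  # torch.Size([2, 100]): one terminal logit per pair
```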
@@ -265,22 +279,40 @@ class BeliefStateWrapper(Module):
 
         # cross entropy loss
 
-        fb_loss = F.cross_entropy(
+        loss = F.cross_entropy(
             rearrange(logits, 'b n (fb l) -> b l (fb n)', fb = 2),
             rearrange(labels, 'b n fb -> b (fb n)'),
             reduction = 'none' if self.needs_loss_weight else 'mean'
         )
 
+        # maybe predict terminal
+
+        if exists(self.to_terminal_logit):
+            terminal_logits = self.to_terminal_logit(fb_embeds)
+
+            labels = ((bi - fi) == 2).float() # distance is exactly 2
+            labels = repeat(labels, 'n -> b n', b = batch)
+
+            is_end_loss = F.binary_cross_entropy_with_logits(
+                terminal_logits,
+                labels
+            )
+
+            loss = (
+                loss +
+                is_end_loss * self.pred_terminal_loss_weight
+            )
+
         # maybe loss weighting
 
         if self.needs_loss_weight:
-            fb_loss = rearrange(fb_loss, 'b (fb n) -> b fb n', fb = 2)
-            fb_loss = einx.multiply('b fb n, fb', fb_loss, self.loss_weights)
-            fb_loss = fb_loss.mean()
+            loss = rearrange(loss, 'b (fb n) -> b fb n', fb = 2)
+            loss = einx.multiply('b fb n, fb', loss, self.loss_weights)
+            loss = loss.mean()
 
         # backwards
 
-        orig_backward = getattr(fb_loss, 'backward')
+        orig_backward = getattr(loss, 'backward')
 
         def patched_backward_fn(*args, **kwargs):
             orig_backward(*args, **kwargs)
@@ -292,6 +324,6 @@ class BeliefStateWrapper(Module):
         if backward:
             patched_backward_fn()
         else:
-            setattr(fb_loss, 'backward', patched_backward_fn)
+            setattr(loss, 'backward', patched_backward_fn)
 
-        return fb_loss
+        return loss
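In summary, 2.1.16 adds an optional terminal-state prediction objective to `BeliefStateWrapper`: when `pred_terminal = True`, a BCE term (weighted by `pred_terminal_loss_weight`) is added to the forward/backward cross entropy. A hypothetical usage sketch follows; everything other than the two new arguments (the `TransformerWrapper`/`Decoder` configuration and the bare `model(seq)` call returning the loss) is an assumption based on the rest of the library, not something shown in this diff:

```python
import torch
from x_transformers import TransformerWrapper, Decoder
from x_transformers.belief_state_wrapper import BeliefStateWrapper

# toy forward decoder; per the diff, the backward decoder defaults to the same network
forward_decoder = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 1024,
    attn_layers = Decoder(dim = 512, depth = 6, heads = 8)
)

model = BeliefStateWrapper(
    forward_decoder = forward_decoder,
    pred_terminal = True,             # new in 2.1.16: enables the terminal-state head
    pred_terminal_loss_weight = 1.    # new in 2.1.16: weight on the terminal BCE term
)

seq = torch.randint(0, 256, (2, 512))  # (batch, sequence length) of token ids

loss = model(seq)  # assumed call convention: returns the combined training loss
loss.backward()
```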
--- a/x_transformers-2.1.15.dist-info/METADATA
+++ b/x_transformers-2.1.16.dist-info/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: x-transformers
-Version: 2.1.15
+Version: 2.1.16
 Summary: X-Transformers
 Project-URL: Homepage, https://pypi.org/project/x-transformers/
 Project-URL: Repository, https://github.com/lucidrains/x-transformers
--- a/x_transformers-2.1.15.dist-info/RECORD
+++ b/x_transformers-2.1.16.dist-info/RECORD
@@ -1,7 +1,7 @@
 x_transformers/__init__.py,sha256=NDoiBivau559WQ0FvXG4ssU3Il9aoHmTIUFN_1juz0s,911
 x_transformers/attend.py,sha256=-5BWWhFsp7tvZTdN91Ay5SqOjyj9uOs-122vFvoO6b4,17253
 x_transformers/autoregressive_wrapper.py,sha256=reLCno9Z9pchVU79tBF8OMo21LwSZ67KAeB83jqkyAc,10505
-x_transformers/belief_state_wrapper.py,sha256=LclzwJ4FjfRh4b68Y1IJsWsmo2ymffbutnOqfTg-LdM,8854
+x_transformers/belief_state_wrapper.py,sha256=VjUB73yFBWevN6xMc6_1s-Yc58pJv8SDAUUEXwpR-W0,9842
 x_transformers/continuous.py,sha256=p0sCAiH1na236ygwgL1Yyhu36eZBf9cZvoW1JyP_fFE,7073
 x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
@@ -10,7 +10,7 @@ x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dC
 x_transformers/x_transformers.py,sha256=fqgtIs6__JpLWMnJa8AY5OW3AJ2GR1B5p-9TsWdiOIU,110425
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=7S00kCuab4tWQa-vf-z-XfzADjVj48MoFIr7VSIvttg,8575
-x_transformers-2.1.15.dist-info/METADATA,sha256=zV-K3eS2O8ld1ph7JqTW31EcqFasoEh625OPXHz2N78,87571
-x_transformers-2.1.15.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-x_transformers-2.1.15.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-2.1.15.dist-info/RECORD,,
+x_transformers-2.1.16.dist-info/METADATA,sha256=MXaag1fuq1BsAmyh9k8sSRiHpy-jAUQW2Hn1GC53MnQ,87571
+x_transformers-2.1.16.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+x_transformers-2.1.16.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-2.1.16.dist-info/RECORD,,