x-transformers 2.1.24__py3-none-any.whl → 2.1.25__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -48,8 +48,9 @@ class BeliefStateWrapper(Module):
         train_frac_forward_backward_pairs: float = 1.,
         text_head: Module | None = None,
         backward_ar_loss_weight: float = 1., # can weigh the training of the backwards decoder differently, perhaps fwd/bwd have a shared backbone etc etc
-        pred_terminal = False,
-        pred_terminal_loss_weight: float = 1.
+        pred_distance = False,
+        pred_distance_loss_weight: float = 1.,
+        max_pred_distance = None
     ):
         super().__init__()
         backward_decoder = default(backward_decoder, forward_decoder) # if backward decoder not set, use the same transformer, assume it knows how to switch gears based on suffix token
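
This hunk replaces the binary `pred_terminal` option (and its loss weight) with a `pred_distance` option, adding a `max_pred_distance` cap. A minimal usage sketch of the new keyword arguments, assuming the usual `TransformerWrapper`/`Decoder` construction from the x-transformers README; the sketch is illustrative and not part of the diff:

    # illustrative only: wiring the three new 2.1.25 constructor kwargs
    from x_transformers import TransformerWrapper, Decoder
    from x_transformers.belief_state_wrapper import BeliefStateWrapper

    forward_decoder = TransformerWrapper(
        num_tokens = 256,
        max_seq_len = 1024,
        attn_layers = Decoder(dim = 512, depth = 6)
    )

    wrapper = BeliefStateWrapper(
        forward_decoder = forward_decoder,
        pred_distance = True,            # was: pred_terminal = True
        pred_distance_loss_weight = 1.,  # was: pred_terminal_loss_weight
        max_pred_distance = None         # None falls back to forward_decoder.max_seq_len
    )
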
@@ -59,6 +60,7 @@ class BeliefStateWrapper(Module):
 
         dim = forward_decoder.emb_dim
         num_tokens = forward_decoder.num_tokens
+        max_seq_len = forward_decoder.max_seq_len
 
         self.num_tokens = num_tokens
 
@@ -80,14 +82,15 @@ class BeliefStateWrapper(Module):
 
         # predicting terminal state (when suffix and prefix predict the same token)
 
-        self.to_terminal_logit = nn.Sequential(
+        self.max_pred_distance = default(max_pred_distance, max_seq_len)
+
+        self.to_distance_logits = nn.Sequential(
             nn.Linear(dim * 2, dim),
             nn.LeakyReLU(),
-            nn.Linear(dim, 1),
-            Rearrange('... 1 -> ...')
-        ) if pred_terminal else None
+            nn.Linear(dim, self.max_pred_distance),
+        ) if pred_distance else None
 
-        self.pred_terminal_loss_weight = pred_terminal_loss_weight
+        self.pred_distance_loss_weight = pred_distance_loss_weight
 
         # the two decoders, one which is causal forward, the other causal backwards
 
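With this change, the head maps the concatenated forward/backward embeddings to `max_pred_distance` logits rather than a single terminal logit, so the `Rearrange('... 1 -> ...')` squeeze is dropped and the cap defaults to the decoder's `max_seq_len`. A standalone shape sketch (values arbitrary, not taken from the package):

    # standalone shape check mirroring the new head, illustrative only
    import torch
    from torch import nn

    dim, max_pred_distance = 512, 1024

    to_distance_logits = nn.Sequential(
        nn.Linear(dim * 2, dim),            # fuse forward + backward embeddings
        nn.LeakyReLU(),
        nn.Linear(dim, max_pred_distance),  # one logit per candidate distance
    )

    fb_embeds = torch.randn(2, 16, dim * 2)  # (batch, pairs, fwd dim + bwd dim)
    assert to_distance_logits(fb_embeds).shape == (2, 16, max_pred_distance)
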
@@ -314,20 +317,20 @@ class BeliefStateWrapper(Module):
 
         # maybe predict terminal
 
-        if exists(self.to_terminal_logit):
-            terminal_logits = self.to_terminal_logit(fb_embeds)
+        if exists(self.to_distance_logits):
+            distance_logits = self.to_distance_logits(fb_embeds)
 
-            terminal_labels = ((bi - fi) == 2).float() # distance is exactly 2
-            terminal_labels = repeat(terminal_labels, 'n -> b n', b = batch)
+            distance_labels = bi - fi
+            distance_labels = repeat(distance_labels, 'n -> b n', b = batch)
 
-            is_end_loss = F.binary_cross_entropy_with_logits(
-                terminal_logits,
-                terminal_labels
+            pred_dist_loss = F.cross_entropy(
+                rearrange(distance_logits, 'b n l -> b l n'),
+                distance_labels
             )
 
             loss = (
                 loss +
-                is_end_loss * self.pred_terminal_loss_weight
+                pred_dist_loss * self.pred_distance_loss_weight
             )
 
         # maybe early return loss
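
The objective changes in kind, not just in name: the old head was trained with binary cross-entropy on a label that was 1 only when the prefix/suffix distance `bi - fi` was exactly 2, while the new head treats the distance itself as a class index and trains with ordinary cross-entropy. `F.cross_entropy` wants the class dimension second, hence the `rearrange` from `b n l` to `b l n`. A self-contained sketch with synthetic tensors, mirroring only the shapes and names of the hunk above (nothing else is taken from the package):

    # synthetic re-creation of the new loss term, illustrative only
    import torch
    import torch.nn.functional as F
    from einops import rearrange, repeat

    batch, n, max_pred_distance = 2, 16, 1024

    distance_logits = torch.randn(batch, n, max_pred_distance)

    fi = torch.randint(0, 512, (n,))       # forward (prefix) positions
    bi = fi + torch.randint(1, 512, (n,))  # backward (suffix) positions, bi > fi

    # the integer distance is the class label, broadcast across the batch
    distance_labels = repeat(bi - fi, 'n -> b n', b = batch)

    pred_dist_loss = F.cross_entropy(
        rearrange(distance_logits, 'b n l -> b l n'),  # class dim must be second
        distance_labels
    )
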
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: x-transformers
-Version: 2.1.24
+Version: 2.1.25
 Summary: X-Transformers
 Project-URL: Homepage, https://pypi.org/project/x-transformers/
 Project-URL: Repository, https://github.com/lucidrains/x-transformers
@@ -1,7 +1,7 @@
 x_transformers/__init__.py,sha256=NDoiBivau559WQ0FvXG4ssU3Il9aoHmTIUFN_1juz0s,911
 x_transformers/attend.py,sha256=-5BWWhFsp7tvZTdN91Ay5SqOjyj9uOs-122vFvoO6b4,17253
 x_transformers/autoregressive_wrapper.py,sha256=reLCno9Z9pchVU79tBF8OMo21LwSZ67KAeB83jqkyAc,10505
-x_transformers/belief_state_wrapper.py,sha256=S_XP0RQpXemGibYLrVHgNAm0I0IqeocgSMlShcwRqG8,11452
+x_transformers/belief_state_wrapper.py,sha256=xPzhUZYm7qdHYp9fQ73HjwvWmEhve6-cEisvYK5serI,11571
 x_transformers/continuous.py,sha256=p0sCAiH1na236ygwgL1Yyhu36eZBf9cZvoW1JyP_fFE,7073
 x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
@@ -10,7 +10,7 @@ x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dC
 x_transformers/x_transformers.py,sha256=fqgtIs6__JpLWMnJa8AY5OW3AJ2GR1B5p-9TsWdiOIU,110425
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=7S00kCuab4tWQa-vf-z-XfzADjVj48MoFIr7VSIvttg,8575
-x_transformers-2.1.24.dist-info/METADATA,sha256=sl2x1DgYJP3FMPO4cZiSJYCISkgrNA1flgdeVYFl3GA,87875
-x_transformers-2.1.24.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-x_transformers-2.1.24.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-2.1.24.dist-info/RECORD,,
+x_transformers-2.1.25.dist-info/METADATA,sha256=h182tY5ffDUrwUr8VZYekkTeMcDlMD_Krw1XP2K0YWU,87875
+x_transformers-2.1.25.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+x_transformers-2.1.25.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-2.1.25.dist-info/RECORD,,