x-transformers 2.1.21__py3-none-any.whl → 2.1.23__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as published to PyPI, and is provided for informational purposes only. Functionally, 2.1.23 hoists the suffix SOS token embedding in belief_state_wrapper.py so it is also available (and prepended to the main decoder) when decoding backwards, adds a loss_weight_by_fb_indices hook for per-pair loss weighting, and appends a bibtex entry for "GPT or BERT: why not both?" to the README.
--- a/x_transformers/belief_state_wrapper.py
+++ b/x_transformers/belief_state_wrapper.py
@@ -150,6 +150,10 @@ class BeliefStateWrapper(Module):
 
         # get the encoded suffix token once
 
+        suffix_sos_tokens = rearrange(self.suffix_token, 'd -> 1 1 d')
+
+        suffix_sos_tokens = repeat(suffix_sos_tokens, '1 1 d -> b 1 d', b = batch)
+
         if not decode_backwards:
             if exists(suffix):
                 if suffix.ndim == 1:
@@ -157,10 +161,6 @@ class BeliefStateWrapper(Module):
 
                 suffix = suffix.flip(1) # reverse autoregressive
 
-            suffix_sos_tokens = rearrange(self.suffix_token, 'd -> 1 1 d')
-
-            suffix_sos_tokens = repeat(suffix_sos_tokens, '1 1 d -> b 1 d', b = batch)
-
             suffix_embed = self.backward_decoder(
                 suffix,
                 prepend_embeds = suffix_sos_tokens,
@@ -184,6 +184,7 @@ class BeliefStateWrapper(Module):
 
         embeds, new_cache = main_decoder(
             out,
+            prepend_embeds = suffix_sos_tokens if decode_backwards else None,
             return_intermediates = True,
             return_embeddings = True,
             cache = cache,
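Taken together, the three hunks above move the suffix SOS embedding out of the forward-only branch so it also exists when `decode_backwards = True`, where it is now prepended to the main decoder via `prepend_embeds`. Below is a minimal sketch of exercising that path; the wrapper construction follows the project README, but treat the exact argument names and defaults as assumptions to verify against the installed version:

```python
import torch
from x_transformers import TransformerWrapper, Decoder
from x_transformers.belief_state_wrapper import BeliefStateWrapper

# belief state wrapper around a forward decoder; when no backward decoder
# is passed, the forward one is assumed to be reused (per the wrapper's defaults)
model = BeliefStateWrapper(
    forward_decoder = TransformerWrapper(
        num_tokens = 256,
        max_seq_len = 1024,
        attn_layers = Decoder(dim = 512, depth = 6, heads = 8)
    )
)

prompt = torch.randint(0, 256, (2, 4))

# decoding in reverse - the path affected by the prepend_embeds change above
generated = model.generate_with_suffix_cond(prompt, 16, decode_backwards = True)
```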
@@ -227,7 +228,8 @@ class BeliefStateWrapper(Module):
         self,
         seq,
         return_loss_only = False,
-        loss_scale = 1.
+        loss_scale = 1.,
+        loss_weight_by_fb_indices: callable | None = None
     ):
         batch, seq_len, device = *seq.shape, seq.device
 
@@ -335,9 +337,21 @@ class BeliefStateWrapper(Module):
 
         # maybe loss weighting
 
-        if self.needs_loss_weight:
+        needs_loss_weight = default(self.needs_loss_weight, exists(loss_weight_by_fb_indices))
+
+        if needs_loss_weight:
             loss = rearrange(loss, 'b (fb n) -> b fb n', fb = 2)
-            loss = einx.multiply('b fb n, fb', loss, self.loss_weights)
+
+            if self.needs_loss_weight:
+                loss = einx.multiply('b fb n, fb', loss, self.loss_weights)
+
+            # allow researcher to pass in a function that acts on the forward backward indices Int['n fb']
+            # the reason this may be needed is that earlier tokens will have more eligible pairs for training, and perhaps this could be normalized
+
+            if exists(loss_weight_by_fb_indices):
+                loss_weight = loss_weight_by_fb_indices(fb_pairs)
+                loss = einx.multiply('b fb n, n', loss, loss_weight)
+
         loss = loss.mean()
 
         # backwards
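The new `loss_weight_by_fb_indices` hook receives `fb_pairs`, the `Int['n fb']` tensor of sampled forward/backward index pairs, and per the `einx.multiply('b fb n, n', ...)` call must return one weight per pair, i.e. a tensor of shape `'n'`. Here is a sketch of one plausible weighting following the comment's rationale that earlier tokens appear in more eligible pairs; the weighting function itself is a hypothetical example, not something the library prescribes:

```python
import torch
from torch import Tensor

def normalize_by_forward_count(fb_pairs: Tensor) -> Tensor:
    # fb_pairs: (num_pairs, 2) int tensor, with column 0 assumed to hold
    # the forward token indices (an assumption about the Int['n fb'] layout)
    forward_idx = fb_pairs[:, 0]

    # down-weight each pair by how many sampled pairs share its forward
    # index, so early positions with many pairs do not dominate the loss
    counts = torch.bincount(forward_idx)
    return 1. / counts[forward_idx].float()

# passed per forward call, e.g. with the wrapper from the sketch above
seq = torch.randint(0, 256, (2, 1024))
loss = model(seq, loss_weight_by_fb_indices = normalize_by_forward_count)
loss.backward()
```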
--- x_transformers-2.1.21.dist-info/METADATA
+++ x_transformers-2.1.23.dist-info/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: x-transformers
-Version: 2.1.21
+Version: 2.1.23
 Summary: X-Transformers
 Project-URL: Homepage, https://pypi.org/project/x-transformers/
 Project-URL: Repository, https://github.com/lucidrains/x-transformers
@@ -2444,4 +2444,15 @@ ids_out, num_out, is_number_mask = model.generate(start_ids, start_nums, 17)
 }
 ```
 
+```bibtex
+@article{Charpentier2024GPTOB,
+    title   = {GPT or BERT: why not both?},
+    author  = {Lucas Georges Gabriel Charpentier and David Samuel},
+    journal = {ArXiv},
+    year    = {2024},
+    volume  = {abs/2410.24159},
+    url     = {https://api.semanticscholar.org/CorpusID:273707069}
+}
+```
+
 *solve intelligence... then use that to solve everything else.* - Demis Hassabis
--- x_transformers-2.1.21.dist-info/RECORD
+++ x_transformers-2.1.23.dist-info/RECORD
@@ -1,7 +1,7 @@
 x_transformers/__init__.py,sha256=NDoiBivau559WQ0FvXG4ssU3Il9aoHmTIUFN_1juz0s,911
 x_transformers/attend.py,sha256=-5BWWhFsp7tvZTdN91Ay5SqOjyj9uOs-122vFvoO6b4,17253
 x_transformers/autoregressive_wrapper.py,sha256=reLCno9Z9pchVU79tBF8OMo21LwSZ67KAeB83jqkyAc,10505
-x_transformers/belief_state_wrapper.py,sha256=vQUg5djN8TJVhofhPhtMpMbbFz6d1lGsQOukLZhsa3I,10476
+x_transformers/belief_state_wrapper.py,sha256=qDtASUlmc-JcUM0u7pALFgjS9aEsis-jgL07gxWnjsg,11198
 x_transformers/continuous.py,sha256=p0sCAiH1na236ygwgL1Yyhu36eZBf9cZvoW1JyP_fFE,7073
 x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
@@ -10,7 +10,7 @@ x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dC
 x_transformers/x_transformers.py,sha256=fqgtIs6__JpLWMnJa8AY5OW3AJ2GR1B5p-9TsWdiOIU,110425
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=7S00kCuab4tWQa-vf-z-XfzADjVj48MoFIr7VSIvttg,8575
-x_transformers-2.1.21.dist-info/METADATA,sha256=kTobyUmA8d0yo8NuwHSrpRchyfBMGZZAa2b0Ly3hZec,87571
-x_transformers-2.1.21.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-x_transformers-2.1.21.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-2.1.21.dist-info/RECORD,,
+x_transformers-2.1.23.dist-info/METADATA,sha256=26PtwpKeLHzR8yI_4ezvQvmuvz-P3uqaYcPK_RVNvcU,87875
+x_transformers-2.1.23.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+x_transformers-2.1.23.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-2.1.23.dist-info/RECORD,,