x-transformers 2.1.22__py3-none-any.whl → 2.1.23__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -228,7 +228,8 @@ class BeliefStateWrapper(Module):
228
228
  self,
229
229
  seq,
230
230
  return_loss_only = False,
231
- loss_scale = 1.
231
+ loss_scale = 1.,
232
+ loss_weight_by_fb_indices: callable | None = None
232
233
  ):
233
234
  batch, seq_len, device = *seq.shape, seq.device
234
235
 
@@ -336,9 +337,21 @@ class BeliefStateWrapper(Module):
336
337
 
337
338
  # maybe loss weighting
338
339
 
339
- if self.needs_loss_weight:
340
+ needs_loss_weight = default(self.needs_loss_weight, exists(loss_weight_by_fb_indices))
341
+
342
+ if needs_loss_weight:
340
343
  loss = rearrange(loss, 'b (fb n) -> b fb n', fb = 2)
341
- loss = einx.multiply('b fb n, fb', loss, self.loss_weights)
344
+
345
+ if self.needs_loss_weight:
346
+ loss = einx.multiply('b fb n, fb', loss, self.loss_weights)
347
+
348
+ # allow researcher to pass in a function that acts on the forward backward indices Int['n fb']
349
+ # the reason this may be needed is because the earlier tokens will have more eligible pairs for training, and perhaps this could be normalized
350
+
351
+ if exists(loss_weight_by_fb_indices):
352
+ loss_weight = loss_weight_by_fb_indices(fb_pairs)
353
+ loss = einx.multiply('b fb n, n', loss, loss_weight)
354
+
342
355
  loss = loss.mean()
343
356
 
344
357
  # backwards
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: x-transformers
3
- Version: 2.1.22
3
+ Version: 2.1.23
4
4
  Summary: X-Transformers
5
5
  Project-URL: Homepage, https://pypi.org/project/x-transformers/
6
6
  Project-URL: Repository, https://github.com/lucidrains/x-transformers
@@ -1,7 +1,7 @@
1
1
  x_transformers/__init__.py,sha256=NDoiBivau559WQ0FvXG4ssU3Il9aoHmTIUFN_1juz0s,911
2
2
  x_transformers/attend.py,sha256=-5BWWhFsp7tvZTdN91Ay5SqOjyj9uOs-122vFvoO6b4,17253
3
3
  x_transformers/autoregressive_wrapper.py,sha256=reLCno9Z9pchVU79tBF8OMo21LwSZ67KAeB83jqkyAc,10505
4
- x_transformers/belief_state_wrapper.py,sha256=crj0yaTNmszDYlueGu_plGKpxVg0GKH8Z-B66SQluNs,10550
4
+ x_transformers/belief_state_wrapper.py,sha256=qDtASUlmc-JcUM0u7pALFgjS9aEsis-jgL07gxWnjsg,11198
5
5
  x_transformers/continuous.py,sha256=p0sCAiH1na236ygwgL1Yyhu36eZBf9cZvoW1JyP_fFE,7073
6
6
  x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
7
7
  x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
@@ -10,7 +10,7 @@ x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dC
10
10
  x_transformers/x_transformers.py,sha256=fqgtIs6__JpLWMnJa8AY5OW3AJ2GR1B5p-9TsWdiOIU,110425
11
11
  x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
12
12
  x_transformers/xval.py,sha256=7S00kCuab4tWQa-vf-z-XfzADjVj48MoFIr7VSIvttg,8575
13
- x_transformers-2.1.22.dist-info/METADATA,sha256=FBxTg2dObipuXg2cC_ykYpzLF1AHEebrMyRjoAc0Xk4,87875
14
- x_transformers-2.1.22.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
15
- x_transformers-2.1.22.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
16
- x_transformers-2.1.22.dist-info/RECORD,,
13
+ x_transformers-2.1.23.dist-info/METADATA,sha256=26PtwpKeLHzR8yI_4ezvQvmuvz-P3uqaYcPK_RVNvcU,87875
14
+ x_transformers-2.1.23.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
15
+ x_transformers-2.1.23.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
16
+ x_transformers-2.1.23.dist-info/RECORD,,