x-transformers 2.1.22__py3-none-any.whl → 2.1.24__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- x_transformers/belief_state_wrapper.py +22 -3
- {x_transformers-2.1.22.dist-info → x_transformers-2.1.24.dist-info}/METADATA +1 -1
- {x_transformers-2.1.22.dist-info → x_transformers-2.1.24.dist-info}/RECORD +5 -5
- {x_transformers-2.1.22.dist-info → x_transformers-2.1.24.dist-info}/WHEEL +0 -0
- {x_transformers-2.1.22.dist-info → x_transformers-2.1.24.dist-info}/licenses/LICENSE +0 -0
@@ -228,7 +228,8 @@ class BeliefStateWrapper(Module):
|
|
228
228
|
self,
|
229
229
|
seq,
|
230
230
|
return_loss_only = False,
|
231
|
-
loss_scale = 1
|
231
|
+
loss_scale = 1.,
|
232
|
+
loss_weight_by_fb_indices: callable | None = None
|
232
233
|
):
|
233
234
|
batch, seq_len, device = *seq.shape, seq.device
|
234
235
|
|
@@ -336,9 +337,27 @@ class BeliefStateWrapper(Module):
|
|
336
337
|
|
337
338
|
# maybe loss weighting
|
338
339
|
|
339
|
-
|
340
|
+
needs_loss_weight = default(self.needs_loss_weight, exists(loss_weight_by_fb_indices))
|
341
|
+
|
342
|
+
if needs_loss_weight:
|
340
343
|
loss = rearrange(loss, 'b (fb n) -> b fb n', fb = 2)
|
341
|
-
|
344
|
+
|
345
|
+
if self.needs_loss_weight:
|
346
|
+
loss = einx.multiply('b fb n, fb', loss, self.loss_weights)
|
347
|
+
|
348
|
+
# allow researcher to pass in a function that acts on the the forward backward indices Int['n fb']
|
349
|
+
# the reason this may be needed is because the earlier tokens will have more eligible pairs for training, and perhaps this could be normalized
|
350
|
+
|
351
|
+
if exists(loss_weight_by_fb_indices):
|
352
|
+
loss_weight = loss_weight_by_fb_indices(fb_pairs)
|
353
|
+
|
354
|
+
if loss_weight.ndim == 1:
|
355
|
+
loss = einx.multiply('b fb n, n', loss, loss_weight)
|
356
|
+
elif loss_weight.ndim == 2:
|
357
|
+
loss = einx.multiply('b fb n, n fb', loss, loss_weight)
|
358
|
+
else:
|
359
|
+
raise ValueError('invalid loss weight dims')
|
360
|
+
|
342
361
|
loss = loss.mean()
|
343
362
|
|
344
363
|
# backwards
|
@@ -1,7 +1,7 @@
|
|
1
1
|
x_transformers/__init__.py,sha256=NDoiBivau559WQ0FvXG4ssU3Il9aoHmTIUFN_1juz0s,911
|
2
2
|
x_transformers/attend.py,sha256=-5BWWhFsp7tvZTdN91Ay5SqOjyj9uOs-122vFvoO6b4,17253
|
3
3
|
x_transformers/autoregressive_wrapper.py,sha256=reLCno9Z9pchVU79tBF8OMo21LwSZ67KAeB83jqkyAc,10505
|
4
|
-
x_transformers/belief_state_wrapper.py,sha256=
|
4
|
+
x_transformers/belief_state_wrapper.py,sha256=S_XP0RQpXemGibYLrVHgNAm0I0IqeocgSMlShcwRqG8,11452
|
5
5
|
x_transformers/continuous.py,sha256=p0sCAiH1na236ygwgL1Yyhu36eZBf9cZvoW1JyP_fFE,7073
|
6
6
|
x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
|
7
7
|
x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
|
@@ -10,7 +10,7 @@ x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dC
|
|
10
10
|
x_transformers/x_transformers.py,sha256=fqgtIs6__JpLWMnJa8AY5OW3AJ2GR1B5p-9TsWdiOIU,110425
|
11
11
|
x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
|
12
12
|
x_transformers/xval.py,sha256=7S00kCuab4tWQa-vf-z-XfzADjVj48MoFIr7VSIvttg,8575
|
13
|
-
x_transformers-2.1.
|
14
|
-
x_transformers-2.1.
|
15
|
-
x_transformers-2.1.
|
16
|
-
x_transformers-2.1.
|
13
|
+
x_transformers-2.1.24.dist-info/METADATA,sha256=sl2x1DgYJP3FMPO4cZiSJYCISkgrNA1flgdeVYFl3GA,87875
|
14
|
+
x_transformers-2.1.24.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
15
|
+
x_transformers-2.1.24.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
|
16
|
+
x_transformers-2.1.24.dist-info/RECORD,,
|
File without changes
|
File without changes
|