x-transformers 2.1.34__py3-none-any.whl → 2.1.35__py3-none-any.whl
This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.
- x_transformers/belief_state_wrapper.py +21 -19
- {x_transformers-2.1.34.dist-info → x_transformers-2.1.35.dist-info}/METADATA +1 -1
- {x_transformers-2.1.34.dist-info → x_transformers-2.1.35.dist-info}/RECORD +5 -5
- {x_transformers-2.1.34.dist-info → x_transformers-2.1.35.dist-info}/WHEEL +0 -0
- {x_transformers-2.1.34.dist-info → x_transformers-2.1.35.dist-info}/licenses/LICENSE +0 -0
@@ -54,6 +54,26 @@ def flip(x, dim = 1, lens = None):
 
     return x.gather(dim, flip_indices)
 
+# detach multiple tensors and backward the gradients once
+
+class DetachMultiple(Function):
+
+    @classmethod
+    def forward(self, ctx, *tensors):
+        detached_tensors = tuple(t.detach() for t in tensors)
+
+        for detached_tensor in detached_tensors:
+            detached_tensor.requires_grad_()
+
+        return detached_tensors
+
+    @classmethod
+    def backward(self, ctx, *grads):
+
+        return grads
+
+detach_multiple = DetachMultiple.apply
+
 # wrappers
 
 class BeliefStateWrapper(Module):
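The new DetachMultiple function cuts its inputs out of their autograd graphs in forward (returning detached copies with requires_grad re-enabled) and hands the incoming gradients straight back to the original inputs in backward. A minimal sketch of the resulting behavior, not part of the package itself, assuming detach_multiple is imported from x_transformers.belief_state_wrapper:

import torch
from x_transformers.belief_state_wrapper import detach_multiple

a = torch.randn(3, requires_grad = True)
b = torch.randn(3, requires_grad = True)

# outputs are detached copies, but the Function node still links them back to a and b
a2, b2 = detach_multiple(a * 2., b * 3.)

(a2.sum() + b2.sum()).backward()

# backward() returned the incoming gradients unchanged, so they reach the original leaves
print(a.grad)  # tensor of 2s
print(b.grad)  # tensor of 3s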
@@ -268,8 +288,6 @@ class BeliefStateWrapper(Module):
         self,
         seq,
         lens: Tensor | None = None, # Int['b']
-        return_loss_only = False,
-        loss_scale = 1.,
         loss_weight_by_fb_indices: callable | None = None
     ):
         batch, seq_len, device = *seq.shape, seq.device
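With return_loss_only and loss_scale removed from the signature (and, as a later hunk in this file shows, the internal backward call removed as well), loss scaling and the backward pass are now left to the caller. A hypothetical caller-side sketch, where model stands for an already-constructed BeliefStateWrapper and seq for a batch of token ids:

# hypothetical usage; model and seq are assumed to exist
loss = model(seq)            # forward now simply computes and returns the training loss
(loss * 0.5).backward()      # any scaling, e.g. for gradient accumulation, happens outside the wrapper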
@@ -302,11 +320,7 @@ class BeliefStateWrapper(Module):
 
         # trick to reduce memory on backwards pass
 
-        orig_forward_embeds, forward_embeds = forward_embeds, forward_embeds.detach()
-        orig_backward_embeds, backward_embeds = backward_embeds, backward_embeds.detach()
-
-        forward_embeds.requires_grad_()
-        backward_embeds.requires_grad_()
+        forward_embeds, backward_embeds = detach_multiple(forward_embeds, backward_embeds)
 
         # belief state objective
 
@@ -390,11 +404,6 @@ class BeliefStateWrapper(Module):
                 pred_dist_loss * self.pred_distance_loss_weight
             )
 
-        # maybe early return loss
-
-        if return_loss_only:
-            return loss
-
         # maybe loss weighting
 
         needs_loss_weight = default(self.needs_loss_weight, exists(loss_weight_by_fb_indices))
@@ -420,11 +429,4 @@ class BeliefStateWrapper(Module):
 
         loss = loss.mean()
 
-        # backwards
-
-        (loss * loss_scale).backward()
-
-        orig_forward_embeds.backward(forward_embeds.grad)
-        orig_backward_embeds.backward(backward_embeds.grad)
-
         return loss
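The deleted lines were the second half of the memory trick: backpropagate the belief-state loss only as far as the detached embeddings, then feed the resulting .grad tensors back into the original graphs by hand. DetachMultiple folds both halves into a single autograd node, so one loss.backward() on the caller's side now covers the whole path. A generic sketch of the manual pattern that was removed (names are illustrative, not taken from the package):

import torch

x = torch.randn(4, 8, requires_grad = True)
embeds = x * 2.                                # stand-in for the decoder output embeddings

orig_embeds, embeds = embeds, embeds.detach()  # cut the graph, keep a handle on the original
embeds.requires_grad_()

loss = embeds.pow(2).mean()                    # stand-in for the belief state loss
loss.backward()                                # gradients stop at the detached leaf

orig_embeds.backward(embeds.grad)              # manually push them on into the original graph
assert x.grad is not None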
@@ -1,7 +1,7 @@
 x_transformers/__init__.py,sha256=NDoiBivau559WQ0FvXG4ssU3Il9aoHmTIUFN_1juz0s,911
 x_transformers/attend.py,sha256=-5BWWhFsp7tvZTdN91Ay5SqOjyj9uOs-122vFvoO6b4,17253
 x_transformers/autoregressive_wrapper.py,sha256=reLCno9Z9pchVU79tBF8OMo21LwSZ67KAeB83jqkyAc,10505
-x_transformers/belief_state_wrapper.py,sha256=
+x_transformers/belief_state_wrapper.py,sha256=YLUMk6t2MhFBEw5lHDDHJHcoCxTIkHvxTNY__GGZEKU,13374
 x_transformers/continuous.py,sha256=p0sCAiH1na236ygwgL1Yyhu36eZBf9cZvoW1JyP_fFE,7073
 x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
@@ -10,7 +10,7 @@ x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dC
 x_transformers/x_transformers.py,sha256=oyYk31qhDyt6cCuDeaHBl4XhUo5bfIwmYE_z1E1OpXU,111502
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=7S00kCuab4tWQa-vf-z-XfzADjVj48MoFIr7VSIvttg,8575
-x_transformers-2.1.
-x_transformers-2.1.
-x_transformers-2.1.
-x_transformers-2.1.
+x_transformers-2.1.35.dist-info/METADATA,sha256=tLbl-c1QtaOphTa1DpdNfh4dXzFwTt9Fvdh94tnwdTs,88161
+x_transformers-2.1.35.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+x_transformers-2.1.35.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-2.1.35.dist-info/RECORD,,
File without changes
File without changes