x-transformers 2.1.34__py3-none-any.whl → 2.1.36__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their public registry, and is provided for informational purposes only.
- x_transformers/belief_state_wrapper.py +21 -19
- x_transformers/x_transformers.py +5 -4
- {x_transformers-2.1.34.dist-info → x_transformers-2.1.36.dist-info}/METADATA +1 -1
- {x_transformers-2.1.34.dist-info → x_transformers-2.1.36.dist-info}/RECORD +6 -6
- {x_transformers-2.1.34.dist-info → x_transformers-2.1.36.dist-info}/WHEEL +0 -0
- {x_transformers-2.1.34.dist-info → x_transformers-2.1.36.dist-info}/licenses/LICENSE +0 -0
x_transformers/belief_state_wrapper.py
CHANGED
@@ -54,6 +54,26 @@ def flip(x, dim = 1, lens = None):
 
     return x.gather(dim, flip_indices)
 
+# detach multiple tensors and backward the gradients once
+
+class DetachMultiple(Function):
+
+    @classmethod
+    def forward(self, ctx, *tensors):
+        detached_tensors = tuple(t.detach() for t in tensors)
+
+        for detached_tensor in detached_tensors:
+            detached_tensor.requires_grad_()
+
+        return detached_tensors
+
+    @classmethod
+    def backward(self, ctx, *grads):
+        return grads
+
+detach_multiple = DetachMultiple.apply
+
 # wrappers
 
 class BeliefStateWrapper(Module):
@@ -268,8 +288,6 @@ class BeliefStateWrapper(Module):
         self,
         seq,
         lens: Tensor | None = None, # Int['b']
-        return_loss_only = False,
-        loss_scale = 1.,
         loss_weight_by_fb_indices: callable | None = None
     ):
         batch, seq_len, device = *seq.shape, seq.device
@@ -302,11 +320,7 @@ class BeliefStateWrapper(Module):
 
         # trick to reduce memory on backwards pass
 
-        orig_forward_embeds, forward_embeds = forward_embeds, forward_embeds.detach()
-        orig_backward_embeds, backward_embeds = backward_embeds, backward_embeds.detach()
-
-        forward_embeds.requires_grad_()
-        backward_embeds.requires_grad_()
+        forward_embeds, backward_embeds = detach_multiple(forward_embeds, backward_embeds)
 
         # belief state objective
 
@@ -390,11 +404,6 @@ class BeliefStateWrapper(Module):
                 pred_dist_loss * self.pred_distance_loss_weight
             )
 
-        # maybe early return loss
-
-        if return_loss_only:
-            return loss
-
         # maybe loss weighting
 
         needs_loss_weight = default(self.needs_loss_weight, exists(loss_weight_by_fb_indices))
@@ -420,11 +429,4 @@ class BeliefStateWrapper(Module):
 
         loss = loss.mean()
 
-        # backwards
-
-        (loss * loss_scale).backward()
-
-        orig_forward_embeds.backward(forward_embeds.grad)
-        orig_backward_embeds.backward(backward_embeds.grad)
-
         return loss
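Taken together, these hunks move the backward pass out of BeliefStateWrapper: the manual detach / requires_grad_ / re-backward dance on the forward and backward embeddings is replaced by DetachMultiple, an autograd Function whose backward is the identity, and the wrapper no longer accepts loss_scale or return_loss_only nor calls .backward() itself, so the training loop is expected to backward the loss it returns. A minimal standalone sketch of the detach-with-identity-backward pattern (DetachPair and the toy tensors below are illustrative, not part of the package):

import torch
from torch.autograd import Function

class DetachPair(Function):
    # toy two-tensor version of the pattern: forward hands back detached
    # copies (cutting the graph above them), backward passes the incoming
    # gradients straight through to the original inputs
    @staticmethod
    def forward(ctx, a, b):
        a, b = a.detach(), b.detach()
        a.requires_grad_()
        b.requires_grad_()
        return a, b

    @staticmethod
    def backward(ctx, grad_a, grad_b):
        return grad_a, grad_b

x = torch.randn(4, requires_grad = True)
y = torch.randn(4, requires_grad = True)

xd, yd = DetachPair.apply(x, y)
loss = (xd * yd).sum()

# a single backward on the loss reaches x and y through the identity
# backward, which is what lets the wrapper simply return loss
loss.backward()
assert x.grad is not None and y.grad is not None

If a training loop previously relied on the wrapper's internal (loss * loss_scale).backward(), that scaling and the backward call presumably now belong in the caller.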
x_transformers/x_transformers.py
CHANGED
@@ -864,14 +864,15 @@ class DynamicTanh(Module):
         self.gamma = nn.Parameter(torch.ones(dim))
         self.beta = nn.Parameter(torch.zeros(dim))
 
-        self.
+        self.pre_tanh_scale_offset = init_alpha if unit_offset else 0.
+        self.gamma_offset = float(unit_offset)
 
-        nn.init.constant_(self.pre_tanh_scale,
+        nn.init.constant_(self.pre_tanh_scale, 0 if unit_offset else init_alpha)
         nn.init.constant_(self.gamma, 1. - float(unit_offset))
 
     def forward(self, x):
-        pre_tanh_scale = self.pre_tanh_scale + self.
-        gamma = self.gamma + self.
+        pre_tanh_scale = self.pre_tanh_scale + self.pre_tanh_scale_offset
+        gamma = self.gamma + self.gamma_offset
         return (x * pre_tanh_scale).tanh() * gamma + self.beta
 
 # residual and residual gates
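The removed lines on the left show that the 2.1.34 wheel shipped DynamicTanh with truncated statements (a bare self., an unfinished nn.init.constant_ call), which this release completes by introducing explicit pre_tanh_scale_offset and gamma_offset attributes. A standalone sketch of how the repaired module reads, where the constructor signature and the shape of pre_tanh_scale are assumptions rather than the package's exact code:

import torch
from torch import nn

class DynamicTanhSketch(nn.Module):
    # assumed signature / parameter shapes; the offset bookkeeping and
    # forward() follow the added lines in the hunk above
    def __init__(self, dim, init_alpha = 1., unit_offset = False):
        super().__init__()
        self.pre_tanh_scale = nn.Parameter(torch.ones(1))
        self.gamma = nn.Parameter(torch.ones(dim))
        self.beta = nn.Parameter(torch.zeros(dim))

        # with unit_offset the learnable tensors start at zero and the fixed
        # offsets restore the effective initial values (init_alpha and 1.)
        self.pre_tanh_scale_offset = init_alpha if unit_offset else 0.
        self.gamma_offset = float(unit_offset)

        nn.init.constant_(self.pre_tanh_scale, 0. if unit_offset else init_alpha)
        nn.init.constant_(self.gamma, 1. - float(unit_offset))

    def forward(self, x):
        pre_tanh_scale = self.pre_tanh_scale + self.pre_tanh_scale_offset
        gamma = self.gamma + self.gamma_offset
        return (x * pre_tanh_scale).tanh() * gamma + self.beta

# either parameterization gives the same initial function; the unit_offset
# form just centers the learnable tensors at zero
norm = DynamicTanhSketch(dim = 512, unit_offset = True)
out = norm(torch.randn(2, 1024, 512))

Parameterizing around a unit offset like this is commonly done so that weight decay or zero-init schemes act on a residual around the intended starting value rather than dragging the effective scale toward zero.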
{x_transformers-2.1.34.dist-info → x_transformers-2.1.36.dist-info}/RECORD
CHANGED
@@ -1,16 +1,16 @@
 x_transformers/__init__.py,sha256=NDoiBivau559WQ0FvXG4ssU3Il9aoHmTIUFN_1juz0s,911
 x_transformers/attend.py,sha256=-5BWWhFsp7tvZTdN91Ay5SqOjyj9uOs-122vFvoO6b4,17253
 x_transformers/autoregressive_wrapper.py,sha256=reLCno9Z9pchVU79tBF8OMo21LwSZ67KAeB83jqkyAc,10505
-x_transformers/belief_state_wrapper.py,sha256=
+x_transformers/belief_state_wrapper.py,sha256=YLUMk6t2MhFBEw5lHDDHJHcoCxTIkHvxTNY__GGZEKU,13374
 x_transformers/continuous.py,sha256=p0sCAiH1na236ygwgL1Yyhu36eZBf9cZvoW1JyP_fFE,7073
 x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/neo_mlp.py,sha256=XCNnnop9WLarcxap1kGuYc1x8GHvwkZiDRnXOxSl3Po,3452
 x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
-x_transformers/x_transformers.py,sha256=
+x_transformers/x_transformers.py,sha256=voN-uEBEKxpUu9K4MVcneSTrzdgJWnZGuQ1QRZQw4Q4,111596
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=7S00kCuab4tWQa-vf-z-XfzADjVj48MoFIr7VSIvttg,8575
-x_transformers-2.1.
-x_transformers-2.1.
-x_transformers-2.1.
-x_transformers-2.1.
+x_transformers-2.1.36.dist-info/METADATA,sha256=D0qdMRucK3PWwEi8WwdiJdZ8X_hGTm1r3_7bJzYiWSM,88161
+x_transformers-2.1.36.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+x_transformers-2.1.36.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-2.1.36.dist-info/RECORD,,
{x_transformers-2.1.34.dist-info → x_transformers-2.1.36.dist-info}/WHEEL
File without changes
{x_transformers-2.1.34.dist-info → x_transformers-2.1.36.dist-info}/licenses/LICENSE
File without changes