x-transformers 2.1.34__tar.gz → 2.1.35__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {x_transformers-2.1.34 → x_transformers-2.1.35}/PKG-INFO +1 -1
- {x_transformers-2.1.34 → x_transformers-2.1.35}/pyproject.toml +1 -1
- {x_transformers-2.1.34 → x_transformers-2.1.35}/tests/test_x_transformers.py +1 -0
- {x_transformers-2.1.34 → x_transformers-2.1.35}/train_belief_state.py +3 -5
- {x_transformers-2.1.34 → x_transformers-2.1.35}/x_transformers/belief_state_wrapper.py +21 -19
- {x_transformers-2.1.34 → x_transformers-2.1.35}/.github/FUNDING.yml +0 -0
- {x_transformers-2.1.34 → x_transformers-2.1.35}/.github/workflows/python-publish.yml +0 -0
- {x_transformers-2.1.34 → x_transformers-2.1.35}/.github/workflows/python-test.yaml +0 -0
- {x_transformers-2.1.34 → x_transformers-2.1.35}/.gitignore +0 -0
- {x_transformers-2.1.34 → x_transformers-2.1.35}/LICENSE +0 -0
- {x_transformers-2.1.34 → x_transformers-2.1.35}/README.md +0 -0
- {x_transformers-2.1.34 → x_transformers-2.1.35}/data/README.md +0 -0
- {x_transformers-2.1.34 → x_transformers-2.1.35}/data/enwik8.gz +0 -0
- {x_transformers-2.1.34 → x_transformers-2.1.35}/images/all-attention.png +0 -0
- {x_transformers-2.1.34 → x_transformers-2.1.35}/images/attention-on-attention.png +0 -0
- {x_transformers-2.1.34 → x_transformers-2.1.35}/images/cosine-sim-attention.png +0 -0
- {x_transformers-2.1.34 → x_transformers-2.1.35}/images/deepnorm.png +0 -0
- {x_transformers-2.1.34 → x_transformers-2.1.35}/images/dynamic-pos-bias-linear.png +0 -0
- {x_transformers-2.1.34 → x_transformers-2.1.35}/images/dynamic-pos-bias-log.png +0 -0
- {x_transformers-2.1.34 → x_transformers-2.1.35}/images/dynamic-pos-bias-sinusoidal.png +0 -0
- {x_transformers-2.1.34 → x_transformers-2.1.35}/images/dynamic-pos-bias.png +0 -0
- {x_transformers-2.1.34 → x_transformers-2.1.35}/images/enhanced-recurrence.png +0 -0
- {x_transformers-2.1.34 → x_transformers-2.1.35}/images/fcm.png +0 -0
- {x_transformers-2.1.34 → x_transformers-2.1.35}/images/ffglu.png +0 -0
- {x_transformers-2.1.34 → x_transformers-2.1.35}/images/flash-attention.png +0 -0
- {x_transformers-2.1.34 → x_transformers-2.1.35}/images/gate_values.png +0 -0
- {x_transformers-2.1.34 → x_transformers-2.1.35}/images/gating.png +0 -0
- {x_transformers-2.1.34 → x_transformers-2.1.35}/images/length-extrapolation-scale.png +0 -0
- {x_transformers-2.1.34 → x_transformers-2.1.35}/images/macaron-1.png +0 -0
- {x_transformers-2.1.34 → x_transformers-2.1.35}/images/macaron-2.png +0 -0
- {x_transformers-2.1.34 → x_transformers-2.1.35}/images/memory-transformer.png +0 -0
- {x_transformers-2.1.34 → x_transformers-2.1.35}/images/normformer.png +0 -0
- {x_transformers-2.1.34 → x_transformers-2.1.35}/images/pia.png +0 -0
- {x_transformers-2.1.34 → x_transformers-2.1.35}/images/qknorm-analysis.png +0 -0
- {x_transformers-2.1.34 → x_transformers-2.1.35}/images/resi_dual.png +0 -0
- {x_transformers-2.1.34 → x_transformers-2.1.35}/images/residual_attn.png +0 -0
- {x_transformers-2.1.34 → x_transformers-2.1.35}/images/rezero.png +0 -0
- {x_transformers-2.1.34 → x_transformers-2.1.35}/images/rotary.png +0 -0
- {x_transformers-2.1.34 → x_transformers-2.1.35}/images/sandwich-2.png +0 -0
- {x_transformers-2.1.34 → x_transformers-2.1.35}/images/sandwich.png +0 -0
- {x_transformers-2.1.34 → x_transformers-2.1.35}/images/sandwich_norm.png +0 -0
- {x_transformers-2.1.34 → x_transformers-2.1.35}/images/scalenorm.png +0 -0
- {x_transformers-2.1.34 → x_transformers-2.1.35}/images/talking-heads.png +0 -0
- {x_transformers-2.1.34 → x_transformers-2.1.35}/images/topk-attention.png +0 -0
- {x_transformers-2.1.34 → x_transformers-2.1.35}/images/xval.png +0 -0
- {x_transformers-2.1.34 → x_transformers-2.1.35}/train_copy.py +0 -0
- {x_transformers-2.1.34 → x_transformers-2.1.35}/train_enwik8.py +0 -0
- {x_transformers-2.1.34 → x_transformers-2.1.35}/train_length_extrapolate.py +0 -0
- {x_transformers-2.1.34 → x_transformers-2.1.35}/train_parity.py +0 -0
- {x_transformers-2.1.34 → x_transformers-2.1.35}/x_transformers/__init__.py +0 -0
- {x_transformers-2.1.34 → x_transformers-2.1.35}/x_transformers/attend.py +0 -0
- {x_transformers-2.1.34 → x_transformers-2.1.35}/x_transformers/autoregressive_wrapper.py +0 -0
- {x_transformers-2.1.34 → x_transformers-2.1.35}/x_transformers/continuous.py +0 -0
- {x_transformers-2.1.34 → x_transformers-2.1.35}/x_transformers/dpo.py +0 -0
- {x_transformers-2.1.34 → x_transformers-2.1.35}/x_transformers/multi_input.py +0 -0
- {x_transformers-2.1.34 → x_transformers-2.1.35}/x_transformers/neo_mlp.py +0 -0
- {x_transformers-2.1.34 → x_transformers-2.1.35}/x_transformers/nonautoregressive_wrapper.py +0 -0
- {x_transformers-2.1.34 → x_transformers-2.1.35}/x_transformers/x_transformers.py +0 -0
- {x_transformers-2.1.34 → x_transformers-2.1.35}/x_transformers/xl_autoregressive_wrapper.py +0 -0
- {x_transformers-2.1.34 → x_transformers-2.1.35}/x_transformers/xval.py +0 -0
{x_transformers-2.1.34 → x_transformers-2.1.35}/train_belief_state.py

@@ -106,10 +106,8 @@ for i in tqdm.tqdm(range(NUM_BATCHES), mininterval = 10., desc = 'training'):
     model.train()

     for __ in range(GRADIENT_ACCUMULATE_EVERY):
-        loss = model(
-            next(train_loader),
-            loss_scale = 1./ GRADIENT_ACCUMULATE_EVERY
-        )
+        loss = model(next(train_loader))
+        (loss / GRADIENT_ACCUMULATE_EVERY).backward()

     print(f'training loss: {loss.item()}')
     torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
@@ -119,7 +117,7 @@ for i in tqdm.tqdm(range(NUM_BATCHES), mininterval = 10., desc = 'training'):
     if i % VALIDATE_EVERY == 0:
         model.eval()
         with torch.no_grad():
-            loss = model(next(val_loader)
+            loss = model(next(val_loader))
             print(f'validation loss: {loss.item()}')

     if i % GENERATE_EVERY == 0:
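With this release the training script scales and backpropagates each micro-batch loss itself instead of handing loss_scale to the model. A minimal, self-contained sketch of that gradient-accumulation pattern (the net, loader, and optimizer below are hypothetical stand-ins, not part of the package; the optimizer step is assumed to follow the hunk shown above):

import torch
from torch import nn

GRADIENT_ACCUMULATE_EVERY = 4

net = nn.Linear(16, 1)                                  # stand-in for the wrapped model
optim = torch.optim.Adam(net.parameters(), lr = 1e-3)

def dummy_loader():                                     # stand-in for train_loader
    while True:
        yield torch.randn(8, 16)

train_loader = dummy_loader()

for _ in range(GRADIENT_ACCUMULATE_EVERY):
    loss = net(next(train_loader)).pow(2).mean()
    (loss / GRADIENT_ACCUMULATE_EVERY).backward()       # scale here, accumulate grads across micro-batches

torch.nn.utils.clip_grad_norm_(net.parameters(), 0.5)   # mirrors the script's next line
optim.step()                                            # optimizer step assumed to follow in the script
optim.zero_grad()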
{x_transformers-2.1.34 → x_transformers-2.1.35}/x_transformers/belief_state_wrapper.py

@@ -54,6 +54,26 @@ def flip(x, dim = 1, lens = None):

     return x.gather(dim, flip_indices)

+# detach multiple tensors and backward the gradients once
+
+class DetachMultiple(Function):
+
+    @classmethod
+    def forward(self, ctx, *tensors):
+        detached_tensors = tuple(t.detach() for t in tensors)
+
+        for detached_tensor in detached_tensors:
+            detached_tensor.requires_grad_()
+
+        return detached_tensors
+
+    @classmethod
+    def backward(self, ctx, *grads):
+
+        return grads
+
+detach_multiple = DetachMultiple.apply
+
 # wrappers

 class BeliefStateWrapper(Module):
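The new helper cuts the autograd graph at the given tensors but passes incoming gradients straight through on backward, so a single loss.backward() still reaches whatever produced them. A minimal usage sketch, assuming the helper is importable from x_transformers.belief_state_wrapper as added above (the tensors are illustrative stand-ins for the forward and backward embeddings):

import torch
from x_transformers.belief_state_wrapper import detach_multiple

x = torch.randn(2, 4, requires_grad = True)
y = torch.randn(2, 4, requires_grad = True)

a = x * 2        # non-leaf tensor, standing in for forward_embeds
b = y * 3        # non-leaf tensor, standing in for backward_embeds

a_detached, b_detached = detach_multiple(a, b)   # detached, but gradients pass straight through

loss = (a_detached * b_detached).sum()
loss.backward()

print(x.grad, y.grad)    # expected to be populated by the single backward above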
@@ -268,8 +288,6 @@ class BeliefStateWrapper(Module):
         self,
         seq,
         lens: Tensor | None = None, # Int['b']
-        return_loss_only = False,
-        loss_scale = 1.,
         loss_weight_by_fb_indices: callable | None = None
     ):
         batch, seq_len, device = *seq.shape, seq.device
@@ -302,11 +320,7 @@ class BeliefStateWrapper(Module):

         # trick to reduce memory on backwards pass

-        orig_forward_embeds, forward_embeds = forward_embeds, forward_embeds.detach()
-        orig_backward_embeds, backward_embeds = backward_embeds, backward_embeds.detach()
-
-        forward_embeds.requires_grad_()
-        backward_embeds.requires_grad_()
+        forward_embeds, backward_embeds = detach_multiple(forward_embeds, backward_embeds)

         # belief state objective

@@ -390,11 +404,6 @@ class BeliefStateWrapper(Module):
             pred_dist_loss * self.pred_distance_loss_weight
         )

-        # maybe early return loss
-
-        if return_loss_only:
-            return loss
-
         # maybe loss weighting

         needs_loss_weight = default(self.needs_loss_weight, exists(loss_weight_by_fb_indices))
@@ -420,11 +429,4 @@ class BeliefStateWrapper(Module):

         loss = loss.mean()

-        # backwards
-
-        (loss * loss_scale).backward()
-
-        orig_forward_embeds.backward(forward_embeds.grad)
-        orig_backward_embeds.backward(backward_embeds.grad)
-
         return loss
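For context, the lines removed above implemented the same detach-and-reattach memory trick by hand: detach the embeddings, compute the loss, call backward inside forward, then route the collected gradients back to the original tensors. A standalone sketch of that manual pattern, with hypothetical tensors in place of the embeddings:

import torch

x = torch.randn(2, 4, requires_grad = True)
y = torch.randn(2, 4, requires_grad = True)

orig_forward, orig_backward = x * 2, y * 3       # stand-ins for the two embedding tensors

forward_detached = orig_forward.detach()         # cut the graph to save memory on backward
backward_detached = orig_backward.detach()
forward_detached.requires_grad_()
backward_detached.requires_grad_()

loss = (forward_detached * backward_detached).sum()
loss.backward()                                  # only fills the detached tensors' .grad

orig_forward.backward(forward_detached.grad)     # manually route gradients back,
orig_backward.backward(backward_detached.grad)   # as the removed lines did for the embeds

print(x.grad, y.grad)

In 2.1.35 this bookkeeping moves into detach_multiple, the wrapper's forward only returns the loss, and the caller (see train_belief_state.py above) owns the single backward call.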
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|