x-transformers 2.1.24.tar.gz → 2.1.26.tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {x_transformers-2.1.24 → x_transformers-2.1.26}/PKG-INFO +1 -1
- {x_transformers-2.1.24 → x_transformers-2.1.26}/pyproject.toml +1 -1
- {x_transformers-2.1.24 → x_transformers-2.1.26}/tests/test_x_transformers.py +5 -2
- {x_transformers-2.1.24 → x_transformers-2.1.26}/x_transformers/belief_state_wrapper.py +18 -15
- {x_transformers-2.1.24 → x_transformers-2.1.26}/.github/FUNDING.yml +0 -0
- {x_transformers-2.1.24 → x_transformers-2.1.26}/.github/workflows/python-publish.yml +0 -0
- {x_transformers-2.1.24 → x_transformers-2.1.26}/.github/workflows/python-test.yaml +0 -0
- {x_transformers-2.1.24 → x_transformers-2.1.26}/.gitignore +0 -0
- {x_transformers-2.1.24 → x_transformers-2.1.26}/LICENSE +0 -0
- {x_transformers-2.1.24 → x_transformers-2.1.26}/README.md +0 -0
- {x_transformers-2.1.24 → x_transformers-2.1.26}/data/README.md +0 -0
- {x_transformers-2.1.24 → x_transformers-2.1.26}/data/enwik8.gz +0 -0
- {x_transformers-2.1.24 → x_transformers-2.1.26}/images/all-attention.png +0 -0
- {x_transformers-2.1.24 → x_transformers-2.1.26}/images/attention-on-attention.png +0 -0
- {x_transformers-2.1.24 → x_transformers-2.1.26}/images/cosine-sim-attention.png +0 -0
- {x_transformers-2.1.24 → x_transformers-2.1.26}/images/deepnorm.png +0 -0
- {x_transformers-2.1.24 → x_transformers-2.1.26}/images/dynamic-pos-bias-linear.png +0 -0
- {x_transformers-2.1.24 → x_transformers-2.1.26}/images/dynamic-pos-bias-log.png +0 -0
- {x_transformers-2.1.24 → x_transformers-2.1.26}/images/dynamic-pos-bias-sinusoidal.png +0 -0
- {x_transformers-2.1.24 → x_transformers-2.1.26}/images/dynamic-pos-bias.png +0 -0
- {x_transformers-2.1.24 → x_transformers-2.1.26}/images/enhanced-recurrence.png +0 -0
- {x_transformers-2.1.24 → x_transformers-2.1.26}/images/fcm.png +0 -0
- {x_transformers-2.1.24 → x_transformers-2.1.26}/images/ffglu.png +0 -0
- {x_transformers-2.1.24 → x_transformers-2.1.26}/images/flash-attention.png +0 -0
- {x_transformers-2.1.24 → x_transformers-2.1.26}/images/gate_values.png +0 -0
- {x_transformers-2.1.24 → x_transformers-2.1.26}/images/gating.png +0 -0
- {x_transformers-2.1.24 → x_transformers-2.1.26}/images/length-extrapolation-scale.png +0 -0
- {x_transformers-2.1.24 → x_transformers-2.1.26}/images/macaron-1.png +0 -0
- {x_transformers-2.1.24 → x_transformers-2.1.26}/images/macaron-2.png +0 -0
- {x_transformers-2.1.24 → x_transformers-2.1.26}/images/memory-transformer.png +0 -0
- {x_transformers-2.1.24 → x_transformers-2.1.26}/images/normformer.png +0 -0
- {x_transformers-2.1.24 → x_transformers-2.1.26}/images/pia.png +0 -0
- {x_transformers-2.1.24 → x_transformers-2.1.26}/images/qknorm-analysis.png +0 -0
- {x_transformers-2.1.24 → x_transformers-2.1.26}/images/resi_dual.png +0 -0
- {x_transformers-2.1.24 → x_transformers-2.1.26}/images/residual_attn.png +0 -0
- {x_transformers-2.1.24 → x_transformers-2.1.26}/images/rezero.png +0 -0
- {x_transformers-2.1.24 → x_transformers-2.1.26}/images/rotary.png +0 -0
- {x_transformers-2.1.24 → x_transformers-2.1.26}/images/sandwich-2.png +0 -0
- {x_transformers-2.1.24 → x_transformers-2.1.26}/images/sandwich.png +0 -0
- {x_transformers-2.1.24 → x_transformers-2.1.26}/images/sandwich_norm.png +0 -0
- {x_transformers-2.1.24 → x_transformers-2.1.26}/images/scalenorm.png +0 -0
- {x_transformers-2.1.24 → x_transformers-2.1.26}/images/talking-heads.png +0 -0
- {x_transformers-2.1.24 → x_transformers-2.1.26}/images/topk-attention.png +0 -0
- {x_transformers-2.1.24 → x_transformers-2.1.26}/images/xval.png +0 -0
- {x_transformers-2.1.24 → x_transformers-2.1.26}/train_belief_state.py +0 -0
- {x_transformers-2.1.24 → x_transformers-2.1.26}/train_copy.py +0 -0
- {x_transformers-2.1.24 → x_transformers-2.1.26}/train_enwik8.py +0 -0
- {x_transformers-2.1.24 → x_transformers-2.1.26}/train_length_extrapolate.py +0 -0
- {x_transformers-2.1.24 → x_transformers-2.1.26}/train_parity.py +0 -0
- {x_transformers-2.1.24 → x_transformers-2.1.26}/x_transformers/__init__.py +0 -0
- {x_transformers-2.1.24 → x_transformers-2.1.26}/x_transformers/attend.py +0 -0
- {x_transformers-2.1.24 → x_transformers-2.1.26}/x_transformers/autoregressive_wrapper.py +0 -0
- {x_transformers-2.1.24 → x_transformers-2.1.26}/x_transformers/continuous.py +0 -0
- {x_transformers-2.1.24 → x_transformers-2.1.26}/x_transformers/dpo.py +0 -0
- {x_transformers-2.1.24 → x_transformers-2.1.26}/x_transformers/multi_input.py +0 -0
- {x_transformers-2.1.24 → x_transformers-2.1.26}/x_transformers/neo_mlp.py +0 -0
- {x_transformers-2.1.24 → x_transformers-2.1.26}/x_transformers/nonautoregressive_wrapper.py +0 -0
- {x_transformers-2.1.24 → x_transformers-2.1.26}/x_transformers/x_transformers.py +0 -0
- {x_transformers-2.1.24 → x_transformers-2.1.26}/x_transformers/xl_autoregressive_wrapper.py +0 -0
- {x_transformers-2.1.24 → x_transformers-2.1.26}/x_transformers/xval.py +0 -0
tests/test_x_transformers.py

@@ -696,9 +696,11 @@ def test_lime(
 
 @pytest.mark.parametrize('backward_ar_loss_weight', (1., 0.5))
 @pytest.mark.parametrize('goal_suffix', (False, True))
+@pytest.mark.parametrize('pred_distance', (False, True))
 def test_belief_state_wrapper(
     backward_ar_loss_weight,
-    goal_suffix
+    goal_suffix,
+    pred_distance
 ):
     from x_transformers.belief_state_wrapper import BeliefStateWrapper
 
@@ -727,7 +729,8 @@ def test_belief_state_wrapper(
     model = BeliefStateWrapper(
         forward_decoder = forward_model,
         backward_decoder = backward_model,
-        backward_ar_loss_weight = backward_ar_loss_weight
+        backward_ar_loss_weight = backward_ar_loss_weight,
+        pred_distance = pred_distance
     )
 
     seq = torch.randint(0, 20000, (2, 16))
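Judging from the updated test, the new `pred_distance` option is just another constructor flag on `BeliefStateWrapper`; the training call itself is unchanged. Below is a minimal usage sketch mirroring that test — the `TransformerWrapper`/`Decoder` configuration (token count, depth, dims, sequence length) is an illustrative assumption, not the package's actual test fixture.

```python
import torch
from x_transformers import TransformerWrapper, Decoder
from x_transformers.belief_state_wrapper import BeliefStateWrapper

# two small decoders; the sizes here are assumptions for the sketch
forward_model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(dim = 64, depth = 2)
)

backward_model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(dim = 64, depth = 2)
)

model = BeliefStateWrapper(
    forward_decoder = forward_model,
    backward_decoder = backward_model,
    backward_ar_loss_weight = 0.5,
    pred_distance = True   # new in 2.1.26: adds the auxiliary distance-prediction loss
)

seq = torch.randint(0, 20000, (2, 16))

loss = model(seq)   # training forward pass returns the combined loss
loss.backward()
```

With `pred_distance = True`, the returned loss includes the distance term scaled by `pred_distance_loss_weight` (default `1.`), and the number of distance buckets can be capped with `max_pred_distance`.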
x_transformers/belief_state_wrapper.py

@@ -48,8 +48,9 @@ class BeliefStateWrapper(Module):
         train_frac_forward_backward_pairs: float = 1.,
         text_head: Module | None = None,
         backward_ar_loss_weight: float = 1., # can weigh the training of the backwards decoder differently, perhaps fwd/bwd have a shared backbone etc etc
-
-
+        pred_distance = False,
+        pred_distance_loss_weight: float = 1.,
+        max_pred_distance = None
     ):
         super().__init__()
         backward_decoder = default(backward_decoder, forward_decoder) # if backward decoder not set, use the same transformer, assume it knows how to switch gears based on suffix token
@@ -59,6 +60,7 @@ class BeliefStateWrapper(Module):
 
         dim = forward_decoder.emb_dim
         num_tokens = forward_decoder.num_tokens
+        max_seq_len = forward_decoder.max_seq_len
 
         self.num_tokens = num_tokens
 
@@ -80,14 +82,15 @@ class BeliefStateWrapper(Module):
 
         # predicting terminal state (when suffix and prefix predict the same token)
 
-        self.
+        self.max_pred_distance = default(max_pred_distance, max_seq_len)
+
+        self.to_distance_logits = nn.Sequential(
             nn.Linear(dim * 2, dim),
             nn.LeakyReLU(),
-            nn.Linear(dim,
-
-        ) if pred_terminal else None
+            nn.Linear(dim, self.max_pred_distance),
+        ) if pred_distance else None
 
-        self.
+        self.pred_distance_loss_weight = pred_distance_loss_weight
 
         # the two decoders, one which is causal forward, the other causal backwards
 
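Taken together, these hunks replace the old `pred_terminal` head with a distance-classification head: a two-layer MLP over the concatenated forward and backward embeddings, emitting one logit per possible prefix-to-suffix distance, with the bucket count `max_pred_distance` defaulting to the forward decoder's `max_seq_len`. A standalone sketch of that head follows; the `dim`, bucket count, and tensor shapes are illustrative assumptions.

```python
import torch
from torch import nn

dim = 512                 # forward decoder's embedding dim (assumed value)
max_pred_distance = 1024  # defaults to the forward decoder's max_seq_len

# same shape of head as the diff adds: Linear -> LeakyReLU -> Linear over the
# concatenated forward/backward embeddings
to_distance_logits = nn.Sequential(
    nn.Linear(dim * 2, dim),
    nn.LeakyReLU(),
    nn.Linear(dim, max_pred_distance),
)

# fb_embeds: one concatenated (forward, backward) embedding per sampled prefix/suffix pair
fb_embeds = torch.randn(2, 8, dim * 2)           # (batch, pairs, dim * 2)
distance_logits = to_distance_logits(fb_embeds)  # (batch, pairs, max_pred_distance)
```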
@@ -314,20 +317,20 @@ class BeliefStateWrapper(Module):
 
         # maybe predict terminal
 
-        if exists(self.
-
+        if exists(self.to_distance_logits):
+            distance_logits = self.to_distance_logits(fb_embeds)
 
-
-
+            distance_labels = (bi - fi).clamp(max = self.max_pred_distance - 1)
+            distance_labels = repeat(distance_labels, 'n -> b n', b = batch)
 
-
-
-
+            pred_dist_loss = F.cross_entropy(
+                rearrange(distance_logits, 'b n l -> b l n'),
+                distance_labels
             )
 
             loss = (
                 loss +
-
+                pred_dist_loss * self.pred_distance_loss_weight
             )
 
         # maybe early return loss
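The added loss treats the gap between each forward index `fi` and backward index `bi` as a class label, clamps it to `max_pred_distance - 1`, and scores the head's logits with ordinary cross entropy. Here is a self-contained sketch of that computation; the shapes and index values are made up, and in the wrapper `fi`/`bi` come from the sampled prefix/suffix positions.

```python
import torch
import torch.nn.functional as F
from einops import rearrange, repeat

batch, pairs, max_pred_distance = 2, 8, 1024

distance_logits = torch.randn(batch, pairs, max_pred_distance)

# forward / backward positions for each sampled pair (illustrative values)
fi = torch.arange(pairs)
bi = fi + torch.randint(1, 16, (pairs,))

# distance label per pair, capped at the last bucket
distance_labels = (bi - fi).clamp(max = max_pred_distance - 1)
distance_labels = repeat(distance_labels, 'n -> b n', b = batch)

# cross entropy wants the class dimension second, hence the rearrange
pred_dist_loss = F.cross_entropy(
    rearrange(distance_logits, 'b n l -> b l n'),
    distance_labels
)

print(pred_dist_loss)
```

This term is then added to the main loss scaled by `pred_distance_loss_weight`, as the hunk above shows.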