x-transformers 2.11.5__tar.gz → 2.11.7__tar.gz
This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.
- {x_transformers-2.11.5 → x_transformers-2.11.7}/PKG-INFO +1 -1
- {x_transformers-2.11.5 → x_transformers-2.11.7}/pyproject.toml +1 -1
- {x_transformers-2.11.5 → x_transformers-2.11.7}/x_transformers/free_transformer.py +26 -16
- {x_transformers-2.11.5 → x_transformers-2.11.7}/.github/FUNDING.yml +0 -0
- {x_transformers-2.11.5 → x_transformers-2.11.7}/.github/workflows/python-publish.yml +0 -0
- {x_transformers-2.11.5 → x_transformers-2.11.7}/.github/workflows/python-test.yaml +0 -0
- {x_transformers-2.11.5 → x_transformers-2.11.7}/.gitignore +0 -0
- {x_transformers-2.11.5 → x_transformers-2.11.7}/LICENSE +0 -0
- {x_transformers-2.11.5 → x_transformers-2.11.7}/README.md +0 -0
- {x_transformers-2.11.5 → x_transformers-2.11.7}/data/README.md +0 -0
- {x_transformers-2.11.5 → x_transformers-2.11.7}/data/enwik8.gz +0 -0
- {x_transformers-2.11.5 → x_transformers-2.11.7}/images/all-attention.png +0 -0
- {x_transformers-2.11.5 → x_transformers-2.11.7}/images/attention-on-attention.png +0 -0
- {x_transformers-2.11.5 → x_transformers-2.11.7}/images/cosine-sim-attention.png +0 -0
- {x_transformers-2.11.5 → x_transformers-2.11.7}/images/deepnorm.png +0 -0
- {x_transformers-2.11.5 → x_transformers-2.11.7}/images/dynamic-pos-bias-linear.png +0 -0
- {x_transformers-2.11.5 → x_transformers-2.11.7}/images/dynamic-pos-bias-log.png +0 -0
- {x_transformers-2.11.5 → x_transformers-2.11.7}/images/dynamic-pos-bias-sinusoidal.png +0 -0
- {x_transformers-2.11.5 → x_transformers-2.11.7}/images/dynamic-pos-bias.png +0 -0
- {x_transformers-2.11.5 → x_transformers-2.11.7}/images/enhanced-recurrence.png +0 -0
- {x_transformers-2.11.5 → x_transformers-2.11.7}/images/fcm.png +0 -0
- {x_transformers-2.11.5 → x_transformers-2.11.7}/images/ffglu.png +0 -0
- {x_transformers-2.11.5 → x_transformers-2.11.7}/images/flash-attention.png +0 -0
- {x_transformers-2.11.5 → x_transformers-2.11.7}/images/gate_values.png +0 -0
- {x_transformers-2.11.5 → x_transformers-2.11.7}/images/gating.png +0 -0
- {x_transformers-2.11.5 → x_transformers-2.11.7}/images/length-extrapolation-scale.png +0 -0
- {x_transformers-2.11.5 → x_transformers-2.11.7}/images/macaron-1.png +0 -0
- {x_transformers-2.11.5 → x_transformers-2.11.7}/images/macaron-2.png +0 -0
- {x_transformers-2.11.5 → x_transformers-2.11.7}/images/memory-transformer.png +0 -0
- {x_transformers-2.11.5 → x_transformers-2.11.7}/images/normformer.png +0 -0
- {x_transformers-2.11.5 → x_transformers-2.11.7}/images/pia.png +0 -0
- {x_transformers-2.11.5 → x_transformers-2.11.7}/images/qknorm-analysis.png +0 -0
- {x_transformers-2.11.5 → x_transformers-2.11.7}/images/resi_dual.png +0 -0
- {x_transformers-2.11.5 → x_transformers-2.11.7}/images/residual_attn.png +0 -0
- {x_transformers-2.11.5 → x_transformers-2.11.7}/images/rezero.png +0 -0
- {x_transformers-2.11.5 → x_transformers-2.11.7}/images/rotary.png +0 -0
- {x_transformers-2.11.5 → x_transformers-2.11.7}/images/sandwich-2.png +0 -0
- {x_transformers-2.11.5 → x_transformers-2.11.7}/images/sandwich.png +0 -0
- {x_transformers-2.11.5 → x_transformers-2.11.7}/images/sandwich_norm.png +0 -0
- {x_transformers-2.11.5 → x_transformers-2.11.7}/images/scalenorm.png +0 -0
- {x_transformers-2.11.5 → x_transformers-2.11.7}/images/talking-heads.png +0 -0
- {x_transformers-2.11.5 → x_transformers-2.11.7}/images/topk-attention.png +0 -0
- {x_transformers-2.11.5 → x_transformers-2.11.7}/images/xval.png +0 -0
- {x_transformers-2.11.5 → x_transformers-2.11.7}/tests/test_x_transformers.py +0 -0
- {x_transformers-2.11.5 → x_transformers-2.11.7}/train_belief_state.py +0 -0
- {x_transformers-2.11.5 → x_transformers-2.11.7}/train_copy.py +0 -0
- {x_transformers-2.11.5 → x_transformers-2.11.7}/train_entropy_tokenizer.py +0 -0
- {x_transformers-2.11.5 → x_transformers-2.11.7}/train_enwik8.py +0 -0
- {x_transformers-2.11.5 → x_transformers-2.11.7}/train_free.py +0 -0
- {x_transformers-2.11.5 → x_transformers-2.11.7}/train_gpt_vae.py +0 -0
- {x_transformers-2.11.5 → x_transformers-2.11.7}/train_length_extrapolate.py +0 -0
- {x_transformers-2.11.5 → x_transformers-2.11.7}/train_parity.py +0 -0
- {x_transformers-2.11.5 → x_transformers-2.11.7}/train_with_muon.py +0 -0
- {x_transformers-2.11.5 → x_transformers-2.11.7}/x_transformers/__init__.py +0 -0
- {x_transformers-2.11.5 → x_transformers-2.11.7}/x_transformers/attend.py +0 -0
- {x_transformers-2.11.5 → x_transformers-2.11.7}/x_transformers/autoregressive_wrapper.py +0 -0
- {x_transformers-2.11.5 → x_transformers-2.11.7}/x_transformers/belief_state_wrapper.py +0 -0
- {x_transformers-2.11.5 → x_transformers-2.11.7}/x_transformers/continuous.py +0 -0
- {x_transformers-2.11.5 → x_transformers-2.11.7}/x_transformers/dpo.py +0 -0
- {x_transformers-2.11.5 → x_transformers-2.11.7}/x_transformers/entropy_based_tokenizer.py +0 -0
- {x_transformers-2.11.5 → x_transformers-2.11.7}/x_transformers/gpt_vae.py +0 -0
- {x_transformers-2.11.5 → x_transformers-2.11.7}/x_transformers/multi_input.py +0 -0
- {x_transformers-2.11.5 → x_transformers-2.11.7}/x_transformers/neo_mlp.py +0 -0
- {x_transformers-2.11.5 → x_transformers-2.11.7}/x_transformers/nonautoregressive_wrapper.py +0 -0
- {x_transformers-2.11.5 → x_transformers-2.11.7}/x_transformers/up_wrapper.py +0 -0
- {x_transformers-2.11.5 → x_transformers-2.11.7}/x_transformers/x_transformers.py +0 -0
- {x_transformers-2.11.5 → x_transformers-2.11.7}/x_transformers/xl_autoregressive_wrapper.py +0 -0
- {x_transformers-2.11.5 → x_transformers-2.11.7}/x_transformers/xval.py +0 -0
x_transformers/free_transformer.py

@@ -66,7 +66,6 @@ class BinaryMapper(Module):
 
         self.bits = bits
         self.num_codes = 2 ** bits
-        self.kl_loss_threshold = kl_loss_threshold
 
         power_two = 2 ** arange(bits)
         codes = (arange(self.num_codes)[:, None].bitwise_and(power_two) != 0).byte().bool()
@@ -74,13 +73,20 @@ class BinaryMapper(Module):
         self.register_buffer('power_two', power_two, persistent = False)
         self.register_buffer('codes', codes, persistent = False)
 
+        # aux loss
+
+        self.kl_loss_threshold = kl_loss_threshold
+        self.register_buffer('zero', tensor(0.), persistent = False)
+
     def forward(
         self,
         logits,
         temperature = 1.,
-        straight_through = None
+        straight_through = None,
+        calc_aux_loss = None
     ):
         straight_through = default(straight_through, self.training)
+        calc_aux_loss = default(calc_aux_loss, self.training)
 
         assert logits.shape[-1] == self.bits, f'logits must have a last dimension of {self.bits}'
 
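The new `calc_aux_loss` argument mirrors `straight_through`: both fall back to `self.training` via `default(...)`. For orientation, the `power_two` and `codes` buffers carried over above enumerate every latent code as its binary bit pattern; a minimal standalone sketch of that construction:

```python
import torch

# reproduce the code-table construction from BinaryMapper.__init__
bits = 3
num_codes = 2 ** bits

power_two = 2 ** torch.arange(bits)  # tensor([1, 2, 4])

# row i is the LSB-first bit pattern of code i
codes = (torch.arange(num_codes)[:, None].bitwise_and(power_two) != 0)

print(codes.byte())
# tensor([[0, 0, 0],
#         [1, 0, 0],
#         [0, 1, 0],
#         [1, 1, 0],
#         [0, 0, 1],
#         [1, 0, 1],
#         [0, 1, 1],
#         [1, 1, 1]], dtype=torch.uint8)
```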
@@ -95,26 +101,29 @@ class BinaryMapper(Module):
 
         one_hot = F.one_hot(indices, self.num_codes).float()
 
-        #
+        # maybe calculate aux loss
+
+        aux_kl_loss = self.zero
 
-        if
-
+        if calc_aux_loss:
+            # calculate negative entropy
 
-
+            kl_div = self.bits * NAT - binary_entropy(logits)
+            aux_kl_loss = F.relu(kl_div - self.kl_loss_threshold).mean()
 
-
-        aux_kl_loss = F.relu(kl_div - self.kl_loss_threshold).mean()
+        # maybe straight through
 
-
+        if straight_through:
+            # get the soft G for the gradients and do a straight through
 
-
-
-
-
+            soft_G = (
+                einsum(F.logsigmoid(logits), self.codes.float(), '... bits, codes bits -> ... codes') +
+                einsum(F.logsigmoid(-logits), (~self.codes).float(), '... bits, codes bits -> ... codes')
+            ).exp()
 
-
+            # straight through
 
-
+            one_hot = one_hot + soft_G - soft_G.detach()
 
         return one_hot, aux_kl_loss
 
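Two pieces of the rewritten `forward` deserve a closer look. The aux loss is a hinged KL divergence between the factorized Bernoulli distribution implied by `logits` and the uniform distribution over all `2 ** bits` codes, and the last changed line is a standard straight-through estimator: the forward value stays the hard one-hot while gradients flow through the soft code probabilities `soft_G`. A minimal sketch, assuming `NAT = log(2)` and a `binary_entropy` that sums per-bit Bernoulli entropies (neither definition appears in this diff):

```python
import math
import torch
import torch.nn.functional as F

NAT = math.log(2)  # assumed: entropy of one fair bit, in nats

def binary_entropy(logits):
    # assumed: entropy of independent Bernoulli bits, summed over the bit dim
    p = logits.sigmoid()
    return (-(p * F.logsigmoid(logits) + (1 - p) * F.logsigmoid(-logits))).sum(dim = -1)

bits, threshold = 4, 1.0
logits = torch.randn(2, bits)

# KL(q || uniform) = max entropy (bits * NAT) minus actual entropy,
# hinged so near-uniform codes incur no penalty
kl_div = bits * NAT - binary_entropy(logits)
aux_kl_loss = F.relu(kl_div - threshold).mean()

# straight-through: value equals the hard one-hot, gradient equals soft_G's
soft_G = torch.rand(2, 2 ** bits, requires_grad = True)  # stand-in for the real code probs
one_hot = F.one_hot(torch.randint(0, 2 ** bits, (2,)), 2 ** bits).float()

st = one_hot + soft_G - soft_G.detach()
st.sum().backward()

assert torch.allclose(st, one_hot)                           # forward value unchanged
assert torch.allclose(soft_G.grad, torch.ones_like(soft_G))  # gradient routed to soft_G
```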
@@ -163,6 +172,7 @@ class FreeTransformer(Module):
             cross_attend = True,
             use_rmsnorm = True,
             rotary_pos_emb = True,
+            pre_norm_has_final_norm = True,
             **kwargs,
             **enc_kwargs
         )
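The one-line change to the encoder construction forwards `pre_norm_has_final_norm`, an existing `AttentionLayers` option in x-transformers, so the pre-norm encoder keeps its closing norm. A minimal sketch of the equivalent standalone configuration (dims are placeholders, not from the diff):

```python
from x_transformers import Encoder

# placeholder dims; the point is the flag added in the hunk above
enc = Encoder(
    dim = 512,
    depth = 6,
    cross_attend = True,
    use_rmsnorm = True,
    rotary_pos_emb = True,
    pre_norm_has_final_norm = True  # keep the final RMSNorm after the last pre-norm block
)
```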
@@ -242,7 +252,7 @@ class FreeTransformer(Module):
 
         bit_logits = self.to_latent_bit_logits(pooled)
 
-        one_hot_latents, kl_loss = self.binary_mapper(bit_logits)
+        one_hot_latents, kl_loss = self.binary_mapper(bit_logits, calc_aux_loss = return_kl_loss)
 
         if not return_kl_loss:
             return one_hot_latents
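Net effect: the KL term is now computed only when the caller actually wants it back, instead of unconditionally on every encode. A hedged usage sketch of `BinaryMapper` on its own (the constructor signature is assumed from the attributes shown above):

```python
import torch
from x_transformers.free_transformer import BinaryMapper

mapper = BinaryMapper(bits = 4, kl_loss_threshold = 1.)  # assumed signature

bit_logits = torch.randn(2, 4)

# training-style call: both flags default to self.training
mapper.train()
one_hot, kl_loss = mapper(bit_logits)

# inference-style call, mirroring calc_aux_loss = return_kl_loss above
mapper.eval()
one_hot, kl_loss = mapper(bit_logits, calc_aux_loss = False)
assert kl_loss == 0  # the registered 'zero' buffer is returned untouched
```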