x-transformers 1.42.0__py3-none-any.whl → 1.42.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- x_transformers/neo_mlp.py +9 -0
- x_transformers/x_transformers.py +1 -1
- {x_transformers-1.42.0.dist-info → x_transformers-1.42.3.dist-info}/METADATA +1 -1
- {x_transformers-1.42.0.dist-info → x_transformers-1.42.3.dist-info}/RECORD +7 -7
- {x_transformers-1.42.0.dist-info → x_transformers-1.42.3.dist-info}/LICENSE +0 -0
- {x_transformers-1.42.0.dist-info → x_transformers-1.42.3.dist-info}/WHEEL +0 -0
- {x_transformers-1.42.0.dist-info → x_transformers-1.42.3.dist-info}/top_level.txt +0 -0
x_transformers/neo_mlp.py
CHANGED
@@ -41,6 +41,7 @@ class RandomFourierEmbed(Module):
 
 class NeoMLP(Module):
     """ https://openreview.net/forum?id=A8Vuf2e8y6 """
+    """ https://haian-jin.github.io/projects/LVSM/ """
 
     def __init__(
         self,
@@ -93,6 +94,11 @@ class NeoMLP(Module):
         x,
         return_embeds = False
     ):
+        no_batch = x.ndim == 1
+
+        if no_batch:
+            x = rearrange(x, '... -> 1 ...')
+
         batch = x.shape[0]
 
         fouriered_input = self.random_fourier(x)
@@ -120,6 +126,9 @@ class NeoMLP(Module):
         output = einsum(output_embed, self.to_output_weights, 'b n d, n d -> b n')
         output = output + self.to_output_bias
 
+        if no_batch:
+            output = rearrange(output, '1 ... -> ...')
+
         if not return_embeds:
             return output
 
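Net effect of the three hunks above: NeoMLP.forward now accepts a single unbatched vector by prepending a batch axis on entry and stripping it from the output on exit. A minimal standalone sketch of the pattern, where the input size and the stand-in computation are illustrative assumptions, not the module's internals:

import torch
from einops import rearrange

x = torch.randn(5)                                # unbatched input, ndim == 1

no_batch = x.ndim == 1
if no_batch:
    x = rearrange(x, '... -> 1 ...')              # (5,) -> (1, 5): prepend batch axis

output = x * 2.                                   # stand-in for the MLP forward pass

if no_batch:
    output = rearrange(output, '1 ... -> ...')    # (1, 5) -> (5,): drop batch axis

assert output.shape == (5,)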
x_transformers/x_transformers.py
CHANGED
@@ -1371,7 +1371,7 @@ class Attention(Module):
                 diff_values = repeat(diff_values, 'b h n d -> b (r h) n d', r = h // kv_h)
             else:
                 # https://arxiv.org/abs/2410.17897v1
-                v = v + value_residual
+                v = 0.5 * (v + value_residual)
 
         # attention is all we need
 
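The one-line change above turns the value-residual update into an average. Summing v with the first layer's values roughly doubles the variance at every layer that applies the residual, while the 0.5 * (v + value_residual) form is a convex mix that cannot grow past the scale of its operands. A quick numerical check, with illustrative shapes; value_residual stands in for the values cached from the first attention layer, per the paper linked in the diff:

import torch

v = torch.randn(1, 8, 1024, 64)          # (batch, heads, seq, dim_head)
value_residual = torch.randn_like(v)     # independent, unit variance

summed = v + value_residual              # old update: variance ~ 2.0
mixed = 0.5 * (v + value_residual)       # new update: variance ~ 0.5

print(summed.var().item(), mixed.var().item())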
{x_transformers-1.42.0.dist-info → x_transformers-1.42.3.dist-info}/RECORD
CHANGED
@@ -4,13 +4,13 @@ x_transformers/autoregressive_wrapper.py,sha256=DOJJCMMDOqDYKWy_IaG5IyKsXD3AW6am
 x_transformers/continuous.py,sha256=cIVEdhfei258__ziV7kQBrJMxCel54bExBTDrO9rfCI,6450
 x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
-x_transformers/neo_mlp.py,sha256=
+x_transformers/neo_mlp.py,sha256=XCNnnop9WLarcxap1kGuYc1x8GHvwkZiDRnXOxSl3Po,3452
 x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
-x_transformers/x_transformers.py,sha256=
+x_transformers/x_transformers.py,sha256=K8zt6n7aC8iMMrWJ-0ryNsJeq7eXR94ci5DXAElC-lY,91995
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=7S00kCuab4tWQa-vf-z-XfzADjVj48MoFIr7VSIvttg,8575
-x_transformers-1.42.
-x_transformers-1.42.
-x_transformers-1.42.
-x_transformers-1.42.
-x_transformers-1.42.
+x_transformers-1.42.3.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.42.3.dist-info/METADATA,sha256=A9f9ByY6e62AuBpOEMBndE_b_8TOovaau75D2fo73H4,689
+x_transformers-1.42.3.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
+x_transformers-1.42.3.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.42.3.dist-info/RECORD,,
{x_transformers-1.42.0.dist-info → x_transformers-1.42.3.dist-info}/LICENSE
File without changes

{x_transformers-1.42.0.dist-info → x_transformers-1.42.3.dist-info}/WHEEL
File without changes

{x_transformers-1.42.0.dist-info → x_transformers-1.42.3.dist-info}/top_level.txt
File without changes