x-transformers 1.40.3.tar.gz → 1.40.4.tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {x_transformers-1.40.3/x_transformers.egg-info → x_transformers-1.40.4}/PKG-INFO +1 -1
- {x_transformers-1.40.3 → x_transformers-1.40.4}/README.md +11 -0
- {x_transformers-1.40.3 → x_transformers-1.40.4}/setup.py +1 -1
- {x_transformers-1.40.3 → x_transformers-1.40.4}/tests/test_x_transformers.py +18 -0
- {x_transformers-1.40.3 → x_transformers-1.40.4}/x_transformers/x_transformers.py +14 -0
- {x_transformers-1.40.3 → x_transformers-1.40.4/x_transformers.egg-info}/PKG-INFO +1 -1
- {x_transformers-1.40.3 → x_transformers-1.40.4}/LICENSE +0 -0
- {x_transformers-1.40.3 → x_transformers-1.40.4}/setup.cfg +0 -0
- {x_transformers-1.40.3 → x_transformers-1.40.4}/x_transformers/__init__.py +0 -0
- {x_transformers-1.40.3 → x_transformers-1.40.4}/x_transformers/attend.py +0 -0
- {x_transformers-1.40.3 → x_transformers-1.40.4}/x_transformers/autoregressive_wrapper.py +0 -0
- {x_transformers-1.40.3 → x_transformers-1.40.4}/x_transformers/continuous.py +0 -0
- {x_transformers-1.40.3 → x_transformers-1.40.4}/x_transformers/dpo.py +0 -0
- {x_transformers-1.40.3 → x_transformers-1.40.4}/x_transformers/multi_input.py +0 -0
- {x_transformers-1.40.3 → x_transformers-1.40.4}/x_transformers/nonautoregressive_wrapper.py +0 -0
- {x_transformers-1.40.3 → x_transformers-1.40.4}/x_transformers/xl_autoregressive_wrapper.py +0 -0
- {x_transformers-1.40.3 → x_transformers-1.40.4}/x_transformers/xval.py +0 -0
- {x_transformers-1.40.3 → x_transformers-1.40.4}/x_transformers.egg-info/SOURCES.txt +0 -0
- {x_transformers-1.40.3 → x_transformers-1.40.4}/x_transformers.egg-info/dependency_links.txt +0 -0
- {x_transformers-1.40.3 → x_transformers-1.40.4}/x_transformers.egg-info/requires.txt +0 -0
- {x_transformers-1.40.3 → x_transformers-1.40.4}/x_transformers.egg-info/top_level.txt +0 -0
{x_transformers-1.40.3 → x_transformers-1.40.4}/README.md

@@ -2287,4 +2287,15 @@ ids_out, num_out, is_number_mask = model.generate(start_ids, start_nums, 17)
 }
 ```
 
+```bibtex
+@article{Bai2019DeepEM,
+    title   = {Deep Equilibrium Models},
+    author  = {Shaojie Bai and J. Zico Kolter and Vladlen Koltun},
+    journal = {ArXiv},
+    year    = {2019},
+    volume  = {abs/1909.01377},
+    url     = {https://api.semanticscholar.org/CorpusID:202539738}
+}
+```
+
 *solve intelligence... then use that to solve everything else.* - Demis Hassabis
{x_transformers-1.40.3 → x_transformers-1.40.4}/tests/test_x_transformers.py

@@ -298,3 +298,21 @@ def test_l2_distance(attn_one_kv_head):
     x = torch.randint(0, 256, (1, 1024))
 
     model(x)
+
+def test_reinject_input():
+
+    model = TransformerWrapper(
+        num_tokens = 20000,
+        max_seq_len = 1024,
+        recycling = True,
+        attn_layers = Decoder(
+            dim = 512,
+            depth = 12,
+            heads = 8,
+            reinject_input = True
+        )
+    )
+
+    x = torch.randint(0, 256, (1, 12))
+
+    model(x) # (1, 1024, 20000)
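As a usage-level sketch (assuming the documented `TransformerWrapper` / `Decoder` constructors and the usual `(batch, seq_len, num_tokens)` logits output; the `recycling` flag from the test above is omitted here, and the shape check is illustrative rather than taken from the package), the new option can be exercised like so:

```python
# Minimal sketch: exercising the new reinject_input flag through the public API.
# Constructor arguments mirror the test added in this diff; the shape assertion
# is illustrative and assumes logits of shape (batch, seq_len, num_tokens).
import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 12,
        heads = 8,
        reinject_input = True   # re-add a projection of the input before every layer
    )
)

x = torch.randint(0, 20000, (1, 12))   # (batch, seq_len) of token ids
logits = model(x)

assert logits.shape == (1, 12, 20000)
```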
{x_transformers-1.40.3 → x_transformers-1.40.4}/x_transformers/x_transformers.py

@@ -1353,6 +1353,7 @@ class AttentionLayers(Module):
         use_layerscale = False,
         layerscale_init_value = 0.,
         unet_skips = False,
+        reinject_input = True, # seen first in DEQ paper https://arxiv.org/abs/1909.01377, but later used in a number of papers trying to achieve depthwise generalization https://arxiv.org/abs/2410.03020v1
         **kwargs
     ):
         super().__init__()
@@ -1582,6 +1583,11 @@ class AttentionLayers(Module):
 
         self.skip_combines = ModuleList([])
 
+        # whether there is reinjection of input at every layer
+
+        self.reinject_input = reinject_input
+        self.reinject_input_proj = nn.Linear(dim, dim, bias = False) if reinject_input else None
+
         # iterate and construct layers
 
         for ind, (layer_type, layer_shift_tokens) in enumerate(zip(self.layer_types, shift_tokens)):
@@ -1766,6 +1772,11 @@ class AttentionLayers(Module):
 
         layer_variables = tuple(tuple(layer_variable[i] for i in layers_execute_order) for layer_variable in layer_variables)
 
+        # derived input for reinjection if needed
+
+        if self.reinject_input:
+            inp_inject = self.reinject_input_proj(x)
+
         # store all hiddens for skips
 
         skip_hiddens = []
@@ -1810,6 +1821,9 @@ class AttentionLayers(Module):
             post_branch_norm = maybe(partial)(post_branch_norm, **norm_kwargs)
             post_main_norm = maybe(partial)(post_main_norm, **norm_kwargs)
 
+            if self.reinject_input:
+                x = x + inp_inject
+
             if exists(pre_norm):
                 x = pre_norm(x)
 
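Taken together, these hunks add what the inline comment calls input reinjection: the hidden state entering the layer stack is projected once by a bias-free linear layer, and that projection is added back onto the residual stream before every layer. The following is a rough, self-contained sketch of the pattern on a toy feedforward stack (an illustration of the idea under those assumptions, not the library's `AttentionLayers`):

```python
import torch
from torch import nn

class ToyReinjectedStack(nn.Module):
    """Toy residual stack illustrating DEQ-style input reinjection:
    a fixed projection of the original input is added before every block."""

    def __init__(self, dim, depth, reinject_input = True):
        super().__init__()
        self.reinject_input = reinject_input
        # single bias-free projection of the original input, shared across layers
        self.reinject_input_proj = nn.Linear(dim, dim, bias = False) if reinject_input else None

        self.blocks = nn.ModuleList([
            nn.Sequential(nn.LayerNorm(dim), nn.Linear(dim, dim), nn.GELU(), nn.Linear(dim, dim))
            for _ in range(depth)
        ])

    def forward(self, x):
        # derive the injected signal once, from the stack's original input
        inp_inject = self.reinject_input_proj(x) if self.reinject_input else None

        for block in self.blocks:
            if self.reinject_input:
                x = x + inp_inject      # reinject the (projected) input at every layer

            x = block(x) + x            # ordinary residual block

        return x

stack = ToyReinjectedStack(dim = 512, depth = 4)
out = stack(torch.randn(1, 12, 512))    # (1, 12, 512)
```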