x-transformers 1.41.5__tar.gz → 1.42.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (22)
  1. {x_transformers-1.41.5/x_transformers.egg-info → x_transformers-1.42.0}/PKG-INFO +1 -1
  2. {x_transformers-1.41.5 → x_transformers-1.42.0}/README.md +11 -0
  3. {x_transformers-1.41.5 → x_transformers-1.42.0}/setup.py +1 -1
  4. {x_transformers-1.41.5 → x_transformers-1.42.0}/tests/test_x_transformers.py +20 -1
  5. {x_transformers-1.41.5 → x_transformers-1.42.0}/x_transformers/__init__.py +4 -0
  6. x_transformers-1.42.0/x_transformers/neo_mlp.py +126 -0
  7. {x_transformers-1.41.5 → x_transformers-1.42.0/x_transformers.egg-info}/PKG-INFO +1 -1
  8. {x_transformers-1.41.5 → x_transformers-1.42.0}/x_transformers.egg-info/SOURCES.txt +1 -0
  9. {x_transformers-1.41.5 → x_transformers-1.42.0}/LICENSE +0 -0
  10. {x_transformers-1.41.5 → x_transformers-1.42.0}/setup.cfg +0 -0
  11. {x_transformers-1.41.5 → x_transformers-1.42.0}/x_transformers/attend.py +0 -0
  12. {x_transformers-1.41.5 → x_transformers-1.42.0}/x_transformers/autoregressive_wrapper.py +0 -0
  13. {x_transformers-1.41.5 → x_transformers-1.42.0}/x_transformers/continuous.py +0 -0
  14. {x_transformers-1.41.5 → x_transformers-1.42.0}/x_transformers/dpo.py +0 -0
  15. {x_transformers-1.41.5 → x_transformers-1.42.0}/x_transformers/multi_input.py +0 -0
  16. {x_transformers-1.41.5 → x_transformers-1.42.0}/x_transformers/nonautoregressive_wrapper.py +0 -0
  17. {x_transformers-1.41.5 → x_transformers-1.42.0}/x_transformers/x_transformers.py +0 -0
  18. {x_transformers-1.41.5 → x_transformers-1.42.0}/x_transformers/xl_autoregressive_wrapper.py +0 -0
  19. {x_transformers-1.41.5 → x_transformers-1.42.0}/x_transformers/xval.py +0 -0
  20. {x_transformers-1.41.5 → x_transformers-1.42.0}/x_transformers.egg-info/dependency_links.txt +0 -0
  21. {x_transformers-1.41.5 → x_transformers-1.42.0}/x_transformers.egg-info/requires.txt +0 -0
  22. {x_transformers-1.41.5 → x_transformers-1.42.0}/x_transformers.egg-info/top_level.txt +0 -0
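
The substantive change in 1.42.0 is a new `x_transformers/neo_mlp.py` module implementing `NeoMLP`, a neural-field MLP whose input, hidden, and output units are embedded as tokens and mixed with self-attention (see the new file and the README citation in the hunks below). A minimal sketch of how the class can be exercised, with dimensions mirroring the new test in this diff; the top-level import assumes the `__init__.py` change below:

```python
import torch
from x_transformers import NeoMLP   # re-exported at the top level in this release

# dimensions mirror tests/test_x_transformers.py::test_neo_mlp
mlp = NeoMLP(
    dim_in = 5,       # number of input features (one token per feature)
    dim_out = 7,      # number of output features (one token per output)
    dim_hidden = 16,  # number of hidden "neuron" tokens
    dim_model = 64,   # embedding width of every token
    depth = 5         # depth of the attention encoder mixing the tokens
)

x = torch.randn(3, 5)   # (batch, dim_in) continuous inputs
out = mlp(x)            # (batch, dim_out)
assert out.shape == (3, 7)
```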
{x_transformers-1.41.5/x_transformers.egg-info → x_transformers-1.42.0}/PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: x-transformers
- Version: 1.41.5
+ Version: 1.42.0
  Summary: X-Transformers - Pytorch
  Home-page: https://github.com/lucidrains/x-transformers
  Author: Phil Wang
{x_transformers-1.41.5 → x_transformers-1.42.0}/README.md
@@ -2341,4 +2341,15 @@ ids_out, num_out, is_number_mask = model.generate(start_ids, start_nums, 17)
  }
  ```

+ ```bibtex
+ @inproceedings{anonymous2024from,
+     title = {From {MLP} to Neo{MLP}: Leveraging Self-Attention for Neural Fields},
+     author = {Anonymous},
+     booktitle = {Submitted to The Thirteenth International Conference on Learning Representations},
+     year = {2024},
+     url = {https://openreview.net/forum?id=A8Vuf2e8y6},
+     note = {under review}
+ }
+ ```
+
  *solve intelligence... then use that to solve everything else.* - Demis Hassabis
{x_transformers-1.41.5 → x_transformers-1.42.0}/setup.py
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
  setup(
    name = 'x-transformers',
    packages = find_packages(exclude=['examples']),
-   version = '1.41.5',
+   version = '1.42.0',
    license='MIT',
    description = 'X-Transformers - Pytorch',
    author = 'Phil Wang',
{x_transformers-1.41.5 → x_transformers-1.42.0}/tests/test_x_transformers.py
@@ -6,7 +6,11 @@ from x_transformers.x_transformers import (
      TransformerWrapper,
      Encoder,
      Decoder,
-     AutoregressiveWrapper
+     AutoregressiveWrapper,
+ )
+
+ from x_transformers.neo_mlp import (
+     NeoMLP
  )

  from x_transformers.multi_input import MultiInputTransformerWrapper
@@ -357,3 +361,18 @@ def test_forgetting_transformer():
      x = torch.randint(0, 20000, (2, 1024))

      embed = model(x)
+
+ def test_neo_mlp():
+
+     mlp = NeoMLP(
+         dim_in = 5,
+         dim_out = 7,
+         dim_hidden = 16,
+         depth = 5,
+         dim_model = 64,
+     )
+
+     x = torch.randn(3, 5)
+
+     out = mlp(x)
+     assert out.shape == (3, 7)
{x_transformers-1.41.5 → x_transformers-1.42.0}/x_transformers/__init__.py
@@ -32,3 +32,7 @@ from x_transformers.xl_autoregressive_wrapper import XLAutoregressiveWrapper
  from x_transformers.dpo import (
      DPO
  )
+
+ from x_transformers.neo_mlp import (
+     NeoMLP
+ )
x_transformers-1.42.0/x_transformers/neo_mlp.py
@@ -0,0 +1,126 @@
+ from collections import namedtuple
+
+ import torch
+ from torch import nn, tensor, pi, is_tensor
+ import torch.nn.functional as F
+ from torch.nn import Module, ModuleList
+
+ from einops import rearrange, repeat, einsum, pack, unpack
+
+ from x_transformers.x_transformers import (
+     Encoder
+ )
+
+ # helpers
+
+ def exists(v):
+     return v is not None
+
+ def default(v, d):
+     return v if exists(v) else d
+
+ # random fourier
+
+ class RandomFourierEmbed(Module):
+
+     def __init__(self, dim):
+         super().__init__()
+         self.proj = nn.Linear(1, dim)
+         self.proj.requires_grad_(False)
+
+     def forward(
+         self,
+         times,
+     ):
+
+         times = rearrange(times, '... -> ... 1')
+         rand_proj = self.proj(times)
+         return torch.cos(2 * pi * rand_proj)
+
+ # class
+
+ class NeoMLP(Module):
+     """ https://openreview.net/forum?id=A8Vuf2e8y6 """
+
+     def __init__(
+         self,
+         *,
+         dim_in,
+         dim_hidden,
+         dim_out,
+         dim_model,
+         depth,
+         encoder_kwargs: dict = dict(
+             attn_dim_head = 16,
+             heads = 4
+         )
+     ):
+         super().__init__()
+
+         # input and output embeddings
+
+         self.input_embed = nn.Parameter(torch.zeros(dim_in, dim_model))
+         self.hidden_embed = nn.Parameter(torch.zeros(dim_hidden, dim_model))
+         self.output_embed = nn.Parameter(torch.zeros(dim_out, dim_model))
+
+         nn.init.normal_(self.input_embed, std = 0.02)
+         nn.init.normal_(self.hidden_embed, std = 0.02)
+         nn.init.normal_(self.output_embed, std = 0.02)
+
+         # they use random fourier for continuous features
+
+         self.random_fourier = nn.Sequential(
+             RandomFourierEmbed(dim_model),
+             nn.Linear(dim_model, dim_model)
+         )
+
+         # hidden dimensions of mlp replaced with nodes with message passing
+         # which comes back to self attention as a fully connected graph.
+
+         self.transformer = Encoder(
+             dim = dim_model,
+             depth = depth,
+             **encoder_kwargs
+         )
+
+         # output
+
+         self.to_output_weights = nn.Parameter(torch.randn(dim_out, dim_model))
+         self.to_output_bias = nn.Parameter(torch.zeros(dim_out))
+
+     def forward(
+         self,
+         x,
+         return_embeds = False
+     ):
+         batch = x.shape[0]
+
+         fouriered_input = self.random_fourier(x)
+
+         # add fouriered input to the input embedding
+
+         input_embed = fouriered_input + self.input_embed
+
+         hidden_embed, output_embed = tuple(repeat(t, '... -> b ...', b = batch) for t in (self.hidden_embed, self.output_embed))
+
+         # pack all the inputs into one string of tokens for self attention
+
+         embed, packed_shape = pack([input_embed, hidden_embed, output_embed], 'b * d')
+
+         # attention is all you need
+
+         embed = self.transformer(embed)
+
+         # unpack
+
+         input_embed, hidden_embed, output_embed = unpack(embed, packed_shape, 'b * d')
+
+         # project for output
+
+         output = einsum(output_embed, self.to_output_weights, 'b n d, n d -> b n')
+         output = output + self.to_output_bias
+
+         if not return_embeds:
+             return output
+
+         return output, (input_embed, hidden_embed, output_embed)
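
For reference, the new `forward` also exposes the post-attention token embeddings when `return_embeds = True`. A small sketch of that path, with shapes inferred from the `pack`/`unpack` calls above (the variable names here are illustrative only):

```python
import torch
from x_transformers.neo_mlp import NeoMLP

mlp = NeoMLP(dim_in = 5, dim_hidden = 16, dim_out = 7, dim_model = 64, depth = 2)

x = torch.randn(3, 5)

# second return value is the unpacked (input, hidden, output) token embeddings
out, (input_embed, hidden_embed, output_embed) = mlp(x, return_embeds = True)

assert out.shape == (3, 7)
assert input_embed.shape  == (3, 5, 64)    # one token per input feature
assert hidden_embed.shape == (3, 16, 64)   # hidden "neuron" tokens
assert output_embed.shape == (3, 7, 64)    # tokens the outputs are projected from
```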
{x_transformers-1.41.5 → x_transformers-1.42.0/x_transformers.egg-info}/PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: x-transformers
- Version: 1.41.5
+ Version: 1.42.0
  Summary: X-Transformers - Pytorch
  Home-page: https://github.com/lucidrains/x-transformers
  Author: Phil Wang
{x_transformers-1.41.5 → x_transformers-1.42.0}/x_transformers.egg-info/SOURCES.txt
@@ -9,6 +9,7 @@ x_transformers/autoregressive_wrapper.py
  x_transformers/continuous.py
  x_transformers/dpo.py
  x_transformers/multi_input.py
+ x_transformers/neo_mlp.py
  x_transformers/nonautoregressive_wrapper.py
  x_transformers/x_transformers.py
  x_transformers/xl_autoregressive_wrapper.py