x-transformers 1.41.5__py3-none-any.whl → 1.42.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
--- a/x_transformers/__init__.py
+++ b/x_transformers/__init__.py
@@ -32,3 +32,7 @@ from x_transformers.xl_autoregressive_wrapper import XLAutoregressiveWrapper
 from x_transformers.dpo import (
     DPO
 )
+
+from x_transformers.neo_mlp import (
+    NeoMLP
+)
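With this re-export added to __init__.py, the new class should be reachable both from the package's top level and from its own module. A minimal import sketch, assuming the 1.42.0 wheel is installed:

from x_transformers import NeoMLP             # re-exported via __init__.py as of 1.42.0
from x_transformers.neo_mlp import NeoMLP     # direct import of the new module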
--- /dev/null
+++ b/x_transformers/neo_mlp.py
@@ -0,0 +1,126 @@
+from collections import namedtuple
+
+import torch
+from torch import nn, tensor, pi, is_tensor
+import torch.nn.functional as F
+from torch.nn import Module, ModuleList
+
+from einops import rearrange, repeat, einsum, pack, unpack
+
+from x_transformers.x_transformers import (
+    Encoder
+)
+
+# helpers
+
+def exists(v):
+    return v is not None
+
+def default(v, d):
+    return v if exists(v) else d
+
+# random fourier
+
+class RandomFourierEmbed(Module):
+
+    def __init__(self, dim):
+        super().__init__()
+        self.proj = nn.Linear(1, dim)
+        self.proj.requires_grad_(False)
+
+    def forward(
+        self,
+        times,
+    ):
+
+        times = rearrange(times, '... -> ... 1')
+        rand_proj = self.proj(times)
+        return torch.cos(2 * pi * rand_proj)
+
+# class
+
+class NeoMLP(Module):
+    """ https://openreview.net/forum?id=A8Vuf2e8y6 """
+
+    def __init__(
+        self,
+        *,
+        dim_in,
+        dim_hidden,
+        dim_out,
+        dim_model,
+        depth,
+        encoder_kwargs: dict = dict(
+            attn_dim_head = 16,
+            heads = 4
+        )
+    ):
+        super().__init__()
+
+        # input and output embeddings
+
+        self.input_embed = nn.Parameter(torch.zeros(dim_in, dim_model))
+        self.hidden_embed = nn.Parameter(torch.zeros(dim_hidden, dim_model))
+        self.output_embed = nn.Parameter(torch.zeros(dim_out, dim_model))
+
+        nn.init.normal_(self.input_embed, std = 0.02)
+        nn.init.normal_(self.hidden_embed, std = 0.02)
+        nn.init.normal_(self.output_embed, std = 0.02)
+
+        # they use random fourier for continuous features
+
+        self.random_fourier = nn.Sequential(
+            RandomFourierEmbed(dim_model),
+            nn.Linear(dim_model, dim_model)
+        )
+
+        # hidden dimensions of mlp replaced with nodes with message passing
+        # which comes back to self attention as a fully connected graph.
+
+        self.transformer = Encoder(
+            dim = dim_model,
+            depth = depth,
+            **encoder_kwargs
+        )
+
+        # output
+
+        self.to_output_weights = nn.Parameter(torch.randn(dim_out, dim_model))
+        self.to_output_bias = nn.Parameter(torch.zeros(dim_out))
+
+    def forward(
+        self,
+        x,
+        return_embeds = False
+    ):
+        batch = x.shape[0]
+
+        fouriered_input = self.random_fourier(x)
+
+        # add fouriered input to the input embedding
+
+        input_embed = fouriered_input + self.input_embed
+
+        hidden_embed, output_embed = tuple(repeat(t, '... -> b ...', b = batch) for t in (self.hidden_embed, self.output_embed))
+
+        # pack all the inputs into one string of tokens for self attention
+
+        embed, packed_shape = pack([input_embed, hidden_embed, output_embed], 'b * d')
+
+        # attention is all you need
+
+        embed = self.transformer(embed)
+
+        # unpack
+
+        input_embed, hidden_embed, output_embed = unpack(embed, packed_shape, 'b * d')
+
+        # project for output
+
+        output = einsum(output_embed, self.to_output_weights, 'b n d, n d -> b n')
+        output = output + self.to_output_bias
+
+        if not return_embeds:
+            return output
+
+        return output, (input_embed, hidden_embed, output_embed)
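For orientation, a minimal usage sketch of the new class, based only on the constructor and forward signatures visible in the hunk above; the dimension values and shapes below are illustrative assumptions, not taken from the package:

import torch
from x_transformers.neo_mlp import NeoMLP

# illustrative sizes (assumptions, not package defaults)
mlp = NeoMLP(
    dim_in = 5,        # number of continuous input features, one token each
    dim_hidden = 64,   # number of hidden "neuron" tokens
    dim_out = 3,       # number of scalar outputs
    dim_model = 64,    # width of the token embeddings / Encoder
    depth = 4          # Encoder depth
)

x = torch.randn(2, 5)                         # (batch, dim_in) continuous inputs
out = mlp(x)                                  # (batch, dim_out)
out, embeds = mlp(x, return_embeds = True)    # embeds = (input, hidden, output) token embeddings

In this sketch each input feature becomes a token via the random Fourier embedding, the hidden and output nodes are learned tokens, all tokens attend to one another through the fully connected Encoder, and the output tokens are then projected down to scalars.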
--- a/x_transformers-1.41.5.dist-info/METADATA
+++ b/x_transformers-1.42.0.dist-info/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: x-transformers
-Version: 1.41.5
+Version: 1.42.0
 Summary: X-Transformers - Pytorch
 Home-page: https://github.com/lucidrains/x-transformers
 Author: Phil Wang
--- a/x_transformers-1.41.5.dist-info/RECORD
+++ b/x_transformers-1.42.0.dist-info/RECORD
@@ -1,15 +1,16 @@
-x_transformers/__init__.py,sha256=-MkQrSc37cTVDX7AOykxunYnqVtFlQ7lb0Cse5dsGWU,793
+x_transformers/__init__.py,sha256=l0dom8ZYkRzFvnDdgzDboXqrI1tKav3beVE7TN2nHko,844
 x_transformers/attend.py,sha256=SdWlV8Vp5DtpsOzAd0LRhm4VGrJf0lJCGiV2_j_CtoA,17284
 x_transformers/autoregressive_wrapper.py,sha256=DOJJCMMDOqDYKWy_IaG5IyKsXD3AW6amzfUgdAADOLY,10500
 x_transformers/continuous.py,sha256=cIVEdhfei258__ziV7kQBrJMxCel54bExBTDrO9rfCI,6450
 x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
+x_transformers/neo_mlp.py,sha256=nNQCbNM_uxBS_oc4J28BXDVpsvzyFJIbgm9xgEdjL0c,3221
 x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
 x_transformers/x_transformers.py,sha256=UhIbFPXjdQsbFBDHVGmV81LGHTD5qwbusDc5kl3F2A4,91987
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=7S00kCuab4tWQa-vf-z-XfzADjVj48MoFIr7VSIvttg,8575
-x_transformers-1.41.5.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-1.41.5.dist-info/METADATA,sha256=4op6TctcnQVjp6pWaYugV7rxG-5e-pn6wL_qS95d98E,689
-x_transformers-1.41.5.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
-x_transformers-1.41.5.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
-x_transformers-1.41.5.dist-info/RECORD,,
+x_transformers-1.42.0.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.42.0.dist-info/METADATA,sha256=cwUbK81bkaChT3ZTyqRr7biJx87s394uJAQPn0df_38,689
+x_transformers-1.42.0.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
+x_transformers-1.42.0.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.42.0.dist-info/RECORD,,