x-transformers 1.41.5__py3-none-any.whl → 1.42.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -32,3 +32,7 @@ from x_transformers.xl_autoregressive_wrapper import XLAutoregressiveWrapper
 from x_transformers.dpo import (
     DPO
 )
+
+from x_transformers.neo_mlp import (
+    NeoMLP
+)
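
With this hunk the new module is re-exported from the package root, so it can be imported directly. A minimal sketch of the new import path (version pin shown only as an assumption about which release first ships it):

    # available from x-transformers 1.42.x onward
    from x_transformers import NeoMLP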
@@ -0,0 +1,135 @@
+from collections import namedtuple
+
+import torch
+from torch import nn, tensor, pi, is_tensor
+import torch.nn.functional as F
+from torch.nn import Module, ModuleList
+
+from einops import rearrange, repeat, einsum, pack, unpack
+
+from x_transformers.x_transformers import (
+    Encoder
+)
+
+# helpers
+
+def exists(v):
+    return v is not None
+
+def default(v, d):
+    return v if exists(v) else d
+
+# random fourier
+
+class RandomFourierEmbed(Module):
+
+    def __init__(self, dim):
+        super().__init__()
+        self.proj = nn.Linear(1, dim)
+        self.proj.requires_grad_(False)
+
+    def forward(
+        self,
+        times,
+    ):
+
+        times = rearrange(times, '... -> ... 1')
+        rand_proj = self.proj(times)
+        return torch.cos(2 * pi * rand_proj)
+
+# class
+
+class NeoMLP(Module):
+    """ https://openreview.net/forum?id=A8Vuf2e8y6 """
+    """ https://haian-jin.github.io/projects/LVSM/ """
+
+    def __init__(
+        self,
+        *,
+        dim_in,
+        dim_hidden,
+        dim_out,
+        dim_model,
+        depth,
+        encoder_kwargs: dict = dict(
+            attn_dim_head = 16,
+            heads = 4
+        )
+    ):
+        super().__init__()
+
+        # input and output embeddings
+
+        self.input_embed = nn.Parameter(torch.zeros(dim_in, dim_model))
+        self.hidden_embed = nn.Parameter(torch.zeros(dim_hidden, dim_model))
+        self.output_embed = nn.Parameter(torch.zeros(dim_out, dim_model))
+
+        nn.init.normal_(self.input_embed, std = 0.02)
+        nn.init.normal_(self.hidden_embed, std = 0.02)
+        nn.init.normal_(self.output_embed, std = 0.02)
+
+        # they use random fourier for continuous features
+
+        self.random_fourier = nn.Sequential(
+            RandomFourierEmbed(dim_model),
+            nn.Linear(dim_model, dim_model)
+        )
+
+        # hidden dimensions of mlp replaced with nodes with message passing
+        # which comes back to self attention as a fully connected graph.
+
+        self.transformer = Encoder(
+            dim = dim_model,
+            depth = depth,
+            **encoder_kwargs
+        )
+
+        # output
+
+        self.to_output_weights = nn.Parameter(torch.randn(dim_out, dim_model))
+        self.to_output_bias = nn.Parameter(torch.zeros(dim_out))
+
+    def forward(
+        self,
+        x,
+        return_embeds = False
+    ):
+        no_batch = x.ndim == 1
+
+        if no_batch:
+            x = rearrange(x, '... -> 1 ...')
+
+        batch = x.shape[0]
+
+        fouriered_input = self.random_fourier(x)
+
+        # add fouriered input to the input embedding
+
+        input_embed = fouriered_input + self.input_embed
+
+        hidden_embed, output_embed = tuple(repeat(t, '... -> b ...', b = batch) for t in (self.hidden_embed, self.output_embed))
+
+        # pack all the inputs into one string of tokens for self attention
+
+        embed, packed_shape = pack([input_embed, hidden_embed, output_embed], 'b * d')
+
+        # attention is all you need
+
+        embed = self.transformer(embed)
+
+        # unpack
+
+        input_embed, hidden_embed, output_embed = unpack(embed, packed_shape, 'b * d')
+
+        # project for output
+
+        output = einsum(output_embed, self.to_output_weights, 'b n d, n d -> b n')
+        output = output + self.to_output_bias
+
+        if no_batch:
+            output = rearrange(output, '1 ... -> ...')
+
+        if not return_embeds:
+            return output
+
+        return output, (input_embed, hidden_embed, output_embed)
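
A minimal usage sketch for the new module, following the signatures in the hunk above; the dimensions below are chosen purely for illustration and are not taken from the package's own documentation:

    import torch
    from x_transformers import NeoMLP

    # hypothetical dimensions, for illustration only
    mlp = NeoMLP(
        dim_in = 5,       # number of continuous input features
        dim_hidden = 16,  # number of hidden "nodes" (tokens) in the fully connected graph
        dim_out = 7,      # number of output values
        dim_model = 64,   # width of each token embedding fed to the Encoder
        depth = 2         # attention layers
    )

    x = torch.randn(3, 5)     # batch of 3 inputs with 5 continuous features each
    out = mlp(x)              # -> shape (3, 7)

    single = torch.randn(5)   # an unbatched input is also accepted
    out_single = mlp(single)  # -> shape (7,)

    # return_embeds = True additionally returns the (input, hidden, output) token embeddings
    out, (inp_e, hid_e, out_e) = mlp(x, return_embeds = True)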
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: x-transformers
-Version: 1.41.5
+Version: 1.42.2
 Summary: X-Transformers - Pytorch
 Home-page: https://github.com/lucidrains/x-transformers
 Author: Phil Wang
@@ -1,15 +1,16 @@
-x_transformers/__init__.py,sha256=-MkQrSc37cTVDX7AOykxunYnqVtFlQ7lb0Cse5dsGWU,793
+x_transformers/__init__.py,sha256=l0dom8ZYkRzFvnDdgzDboXqrI1tKav3beVE7TN2nHko,844
 x_transformers/attend.py,sha256=SdWlV8Vp5DtpsOzAd0LRhm4VGrJf0lJCGiV2_j_CtoA,17284
 x_transformers/autoregressive_wrapper.py,sha256=DOJJCMMDOqDYKWy_IaG5IyKsXD3AW6amzfUgdAADOLY,10500
 x_transformers/continuous.py,sha256=cIVEdhfei258__ziV7kQBrJMxCel54bExBTDrO9rfCI,6450
 x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
+x_transformers/neo_mlp.py,sha256=XCNnnop9WLarcxap1kGuYc1x8GHvwkZiDRnXOxSl3Po,3452
 x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
 x_transformers/x_transformers.py,sha256=UhIbFPXjdQsbFBDHVGmV81LGHTD5qwbusDc5kl3F2A4,91987
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=7S00kCuab4tWQa-vf-z-XfzADjVj48MoFIr7VSIvttg,8575
-x_transformers-1.41.5.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-1.41.5.dist-info/METADATA,sha256=4op6TctcnQVjp6pWaYugV7rxG-5e-pn6wL_qS95d98E,689
-x_transformers-1.41.5.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
-x_transformers-1.41.5.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
-x_transformers-1.41.5.dist-info/RECORD,,
+x_transformers-1.42.2.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.42.2.dist-info/METADATA,sha256=ebAS0cohNsa_tNNPhpw4D-HjmQq1AHLh7mMJKqOdh6E,689
+x_transformers-1.42.2.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
+x_transformers-1.42.2.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.42.2.dist-info/RECORD,,