x-transformers 2.2.12__py3-none-any.whl → 2.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
x_transformers/continuous.py

@@ -1,9 +1,13 @@
+from __future__ import annotations
+
 import torch
-from torch import nn
+from torch import nn, cat, stack
+from torch.nn import Module
 import torch.nn.functional as F
+from torch.distributions import Normal
 
 import einx
-from einops import reduce, pack, repeat, unpack
+from einops import rearrange, reduce, pack, repeat, unpack
 
 from x_transformers.x_transformers import (
     AttentionLayers,
@@ -23,7 +27,7 @@ def exists(val):
 def default(val, d):
     if exists(val):
         return val
-    return d() if callable(d) else d
+    return d() if not isinstance(d, Module) and callable(d) else d
 
 def masked_mean(t, mask):
     t = einx.where('b n, b n d, -> b n d', mask, t, 0.)
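The default helper gains an isinstance guard because nn.Module instances are themselves callable: without it, passing a loss module (such as the new GaussianNLL()) as a fallback would invoke the module instead of returning it. A minimal illustration of the changed behaviour, with throwaway values that are not part of the package:

from torch.nn import Module, MSELoss

def exists(val):
    return val is not None

def default(val, d):
    if exists(val):
        return val
    # Modules are callable, so exclude them before treating d as a factory
    return d() if not isinstance(d, Module) and callable(d) else d

print(default(None, MSELoss(reduction = 'none')))  # module returned as-is, not called
print(default(None, lambda: 42))                   # plain callables are still invoked -> 42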
@@ -34,9 +38,17 @@ def masked_mean(t, mask):
     masked_average = num / den.clamp(min = 1.)
     return masked_average
 
+# probabilistic loss fn
+
+class GaussianNLL(Module):
+    def forward(self, pred, target):
+        mean, var = pred
+        dist = Normal(mean, var)
+        return -dist.log_prob(target)
+
 # main classes
 
-class ContinuousTransformerWrapper(nn.Module):
+class ContinuousTransformerWrapper(Module):
     def __init__(
         self,
         *,
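GaussianNLL expects pred to be the stacked (mean, variance) pair that the probabilistic wrapper returns, and yields an unreduced, elementwise negative log-likelihood under torch.distributions.Normal. A small sketch of calling it directly, with illustrative shapes that are not taken from the package:

import torch
from x_transformers.continuous import GaussianNLL

# hypothetical shapes: batch 2, sequence 8, feature dim 4
mean     = torch.zeros(2, 8, 4)
variance = torch.ones(2, 8, 4)
target   = torch.randn(2, 8, 4)

loss_fn = GaussianNLL()
loss = loss_fn(torch.stack((mean, variance)), target)

print(loss.shape)  # torch.Size([2, 8, 4]) - left unreduced for downstream masking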
@@ -51,7 +63,8 @@ class ContinuousTransformerWrapper(nn.Module):
         emb_dropout = 0.,
         use_abs_pos_emb = True,
         scaled_sinu_pos_emb = False,
-        average_pool_embed = False
+        average_pool_embed = False,
+        probabilistic = False
     ):
         super().__init__()
         dim = attn_layers.dim
@@ -91,7 +104,12 @@ class ContinuousTransformerWrapper(nn.Module):
         # project in and out
 
         self.project_in = nn.Linear(dim_in, dim, bias = False) if exists(dim_in) else nn.Identity()
-        self.project_out = nn.Linear(dim, dim_out, bias = False) if exists(dim_out) else nn.Identity()
+
+        # output is multipled by 2 for outputting mean and log variance
+
+        self.probabilistic = probabilistic
+
+        self.project_out = nn.Linear(dim, dim_out * (2 if probabilistic else 1), bias = False) if exists(dim_out) else nn.Identity()
 
     def forward(
         self,
@@ -132,13 +150,13 @@ class ContinuousTransformerWrapper(nn.Module):
 
             assert prepend_dim == x.shape[-1], 'prepended embeddings need to have same dimensions as model dimensions'
 
-            x = torch.cat((prepend_embeds, x), dim = -2)
+            x = cat((prepend_embeds, x), dim = -2)
 
             if exists(prepend_mask) or exists(mask):
                 mask = default(mask, lambda: torch.ones((batch, seq), device = device, dtype = torch.bool))
                 prepend_mask = default(prepend_mask, lambda: torch.ones((batch, prepend_seq), device = device, dtype = torch.bool))
 
-                mask = torch.cat((prepend_mask, mask), dim = -1)
+                mask = cat((prepend_mask, mask), dim = -1)
 
         x = self.emb_dropout(x)
 
@@ -159,6 +177,11 @@ class ContinuousTransformerWrapper(nn.Module):
 
         out = self.project_out(x) if not return_embeddings else x
 
+        if not return_embeddings and self.probabilistic:
+            mean, log_var = rearrange(out, '... (d mean_log_var) -> mean_log_var ... d', mean_log_var = 2)
+            variance = log_var.exp()
+            return stack((mean, variance))
+
         if return_intermediates:
             return out, intermediates
 
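When probabilistic = True the projection emits 2 * dim_out channels; the block above splits them into a mean and a log-variance, exponentiates the latter, and returns the pair stacked along a new leading axis. A toy check of that rearrange split, with dimensions chosen only for illustration:

import torch
from einops import rearrange

# pretend project_out produced 2 * dim_out = 6 channels for dim_out = 3
out = torch.randn(1, 4, 6)  # (batch, seq, 2 * dim_out)

mean, log_var = rearrange(out, '... (d mean_log_var) -> mean_log_var ... d', mean_log_var = 2)

print(mean.shape, log_var.shape)  # torch.Size([1, 4, 3]) torch.Size([1, 4, 3])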
@@ -173,24 +196,34 @@ class ContinuousTransformerWrapper(nn.Module):
 
         return out
 
-class ContinuousAutoregressiveWrapper(nn.Module):
+class ContinuousAutoregressiveWrapper(Module):
     def __init__(
         self,
         net: ContinuousTransformerWrapper,
         ignore_index = -100,
         pad_value = 0,
-        loss_fn = nn.MSELoss(reduction = 'none'),
+        loss_fn: Module | None = None,
         equal_loss_weight_batch = False # setting this to True, if the mask is passed in and sequences are variable in length, each sequence will be weighted the same (as opposed to each token)
     ):
         super().__init__()
         self.net = net
         self.max_seq_len = net.max_seq_len
 
+        probabilistic = net.probabilistic
+        self.probabilistic = probabilistic
+
+        loss_fn = default(loss_fn, nn.MSELoss(reduction = 'none') if not probabilistic else GaussianNLL())
+
         self.loss_fn = loss_fn
         self.equal_loss_weight_batch = equal_loss_weight_batch
 
     @torch.no_grad()
-    def generate(self, start_tokens, seq_len, **kwargs):
+    def generate(
+        self,
+        start_tokens,
+        seq_len,
+        **kwargs
+    ):
         device = start_tokens.device
         was_training = self.net.training
         num_dims = len(start_tokens.shape)
@@ -208,8 +241,13 @@ class ContinuousAutoregressiveWrapper(nn.Module):
         for _ in range(seq_len):
             x = out[:, -self.max_seq_len:]
 
-            last = self.net(x, **kwargs)[:, -1:]
-            out = torch.cat((out, last), dim = -2)
+            last_output = self.net(x, **kwargs)[..., -1:, :]
+
+            if self.probabilistic:
+                mean, var = last_output
+                last_output = torch.normal(mean, var)
+
+            out = cat((out, last_output), dim = -2)
 
         out = out[:, t:]
 
@@ -219,7 +257,11 @@ class ContinuousAutoregressiveWrapper(nn.Module):
         self.net.train(was_training)
         return out
 
-    def forward(self, x, **kwargs):
+    def forward(
+        self,
+        x,
+        **kwargs
+    ):
         inp, target = x[:, :-1], x[:, 1:]
 
         assert 'prepend_embeds' not in kwargs
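Taken together, the continuous.py changes add an opt-in probabilistic mode: the transformer predicts a per-feature Gaussian, the autoregressive wrapper defaults its loss to GaussianNLL instead of unreduced MSE, and generation samples with torch.normal at every step. A minimal training and generation sketch, assuming the usual top-level exports and with all hyperparameters chosen arbitrarily:

import torch
from x_transformers import ContinuousTransformerWrapper, ContinuousAutoregressiveWrapper, Decoder

model = ContinuousTransformerWrapper(
    dim_in = 32,
    dim_out = 32,
    max_seq_len = 1024,
    probabilistic = True,               # new flag in 2.3.0
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8
    )
)

# loss_fn is left as None, so it defaults to GaussianNLL() because the net is probabilistic
wrapper = ContinuousAutoregressiveWrapper(model)

seq = torch.randn(1, 1024, 32)          # mock continuous data
loss = wrapper(seq)                     # mean elementwise negative log-likelihood
loss.backward()

# generation samples from Normal(mean, variance) at each autoregressive step
prime = torch.randn(1, 128, 32)
generated = wrapper.generate(prime, 64) # (1, 64, 32)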
x_transformers-2.2.12.dist-info/METADATA → x_transformers-2.3.0.dist-info/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: x-transformers
-Version: 2.2.12
+Version: 2.3.0
 Summary: X-Transformers
 Project-URL: Homepage, https://pypi.org/project/x-transformers/
 Project-URL: Repository, https://github.com/lucidrains/x-transformers
x_transformers-2.2.12.dist-info/RECORD → x_transformers-2.3.0.dist-info/RECORD

@@ -2,7 +2,7 @@ x_transformers/__init__.py,sha256=h3I2ejobgEdy8H7NgV-rP8UaBCnd16-MysvDXH9GMEA,98
 x_transformers/attend.py,sha256=-5BWWhFsp7tvZTdN91Ay5SqOjyj9uOs-122vFvoO6b4,17253
 x_transformers/autoregressive_wrapper.py,sha256=reLCno9Z9pchVU79tBF8OMo21LwSZ67KAeB83jqkyAc,10505
 x_transformers/belief_state_wrapper.py,sha256=YLUMk6t2MhFBEw5lHDDHJHcoCxTIkHvxTNY__GGZEKU,13374
-x_transformers/continuous.py,sha256=p0sCAiH1na236ygwgL1Yyhu36eZBf9cZvoW1JyP_fFE,7073
+x_transformers/continuous.py,sha256=F7qGfpjJbPNvF49Dur0EqKE_KV6avSsO7CSvYC8pDh8,8216
 x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
 x_transformers/entropy_based_tokenizer.py,sha256=F2lO8-v3aLIcVDVNhu7RR-UtRdlmaaYJzBK9m7OnLE8,5018
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
@@ -11,7 +11,7 @@ x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dC
 x_transformers/x_transformers.py,sha256=MF91aJGr2DOjIGe57uqwgyNxCExBg_tI9z7usAJMxOM,112401
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=7S00kCuab4tWQa-vf-z-XfzADjVj48MoFIr7VSIvttg,8575
-x_transformers-2.2.12.dist-info/METADATA,sha256=cWj_UYsNQNf2botGDqO7GkyiUh3msLww0EilFMMhRS0,88687
-x_transformers-2.2.12.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-x_transformers-2.2.12.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-2.2.12.dist-info/RECORD,,
+x_transformers-2.3.0.dist-info/METADATA,sha256=ppv8n1fXmiOkiMvry_5WgmhP8zAHZoxIOjEVbHKDTtE,88686
+x_transformers-2.3.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+x_transformers-2.3.0.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-2.3.0.dist-info/RECORD,,