x-transformers 1.32.2__py3-none-any.whl → 1.32.5__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to a supported public registry. It is provided for informational purposes only and reflects the package contents exactly as they appear in that registry.

x_transformers/__init__.py

@@ -20,6 +20,8 @@ from x_transformers.continuous import (
     ContinuousAutoregressiveWrapper
 )
 
+from x_transformers.multi_input import MultiInputTransformerWrapper
+
 from x_transformers.xval import (
     XValTransformerWrapper,
     XValAutoregressiveWrapper

x_transformers/multi_input.py (new file)

@@ -0,0 +1,292 @@
+import torch
+from torch import nn, Tensor
+from torch.nn import Module, ModuleDict
+import torch.nn.functional as F
+
+from typing import Dict
+
+from einops import pack, repeat, unpack
+
+from x_transformers.x_transformers import (
+    AttentionLayers,
+    ScaledSinusoidalEmbedding,
+    AbsolutePositionalEmbedding,
+    LayerIntermediates,
+    LayerNorm,
+    always,
+    pad_at_dim,
+    is_empty
+)
+
+# helper functions
+
+def exists(val):
+    return val is not None
+
+def default(val, d):
+    if exists(val):
+        return val
+    return d() if callable(d) else d
+
+
+class MultiInputTransformerWrapper(Module):
+    def __init__(
+        self,
+        *,
+        num_tokens: Dict[str, int] = dict(),
+        max_seq_len,
+        attn_layers: AttentionLayers,
+        emb_dim = None,
+        max_mem_len = 0,
+        shift_mem_down = 0,
+        emb_dropout = 0.,
+        post_emb_norm = False,
+        num_memory_tokens = None,
+        memory_tokens_interspersed_every = None,
+        tie_embedding = False,
+        logits_dim = None,
+        return_only_embed = False,
+        num_output_heads = 1,
+        use_abs_pos_emb = True,
+        scaled_sinu_pos_emb = False,
+        emb_frac_gradient = 1., # GLM-130B and Cogview successfully used this, set at 0.1
+        attn_z_loss_weight = 1e-4,
+    ):
+        super().__init__()
+
+        dim = attn_layers.dim
+        emb_dim = default(emb_dim, dim)
+        self.emb_dim = emb_dim
+
+        self.max_seq_len = max_seq_len
+        self.max_mem_len = max_mem_len
+        self.shift_mem_down = shift_mem_down
+
+        no_abs_pos_emb = max_seq_len == 0 or not (use_abs_pos_emb and not attn_layers.disable_abs_pos_emb)
+
+        if no_abs_pos_emb:
+            self.pos_emb = always(0)
+        elif scaled_sinu_pos_emb:
+            self.pos_emb = ScaledSinusoidalEmbedding(emb_dim)
+        else:
+            self.pos_emb = AbsolutePositionalEmbedding(emb_dim, max_seq_len)
+
+        # additional embeddings - say type embedding from BERT
+
+        self.embeds = ModuleDict({f'{name}_embed': nn.Embedding(one_num_tokens, emb_dim) for name, one_num_tokens in num_tokens.items()})
+
+        # fraction of the gradient that should go to the embedding, https://arxiv.org/abs/2105.13290
+
+        self.emb_frac_gradient = emb_frac_gradient
+
+        self.post_emb_norm = LayerNorm(emb_dim) if post_emb_norm else nn.Identity()
+        self.emb_dropout = nn.Dropout(emb_dropout)
+
+        self.project_emb = nn.Linear(emb_dim, dim) if emb_dim != dim else nn.Identity()
+        self.attn_layers = attn_layers
+
+        assert num_output_heads > 0
+
+        # output head, usually to logits of num_tokens
+
+        logits_dim = default(logits_dim, num_tokens)
+
+        self.has_multiple_heads = False
+
+        if return_only_embed:
+            self.to_logits = None
+        elif tie_embedding:
+            self.to_logits = lambda t: t @ self.token_emb.emb.weight.t()
+        elif num_output_heads > 1:
+            self.has_multiple_heads = True
+            self.to_logits = ModuleList([nn.Linear(dim, logits_dim, bias = False) for _ in range(num_output_heads)])
+        else:
+            self.to_logits = nn.Linear(dim, logits_dim, bias = False)
+
+        # memory tokens (like [cls]) from Memory Transformers paper
+
+        num_memory_tokens = default(num_memory_tokens, 0)
+        self.num_memory_tokens = num_memory_tokens
+        if num_memory_tokens > 0:
+            self.memory_tokens = nn.Parameter(torch.randn(num_memory_tokens, dim))
+
+        self.memory_tokens_interspersed_every = memory_tokens_interspersed_every
+
+        # whether can do cached kv decoding
+
+        self.can_cache_kv = self.num_memory_tokens == 0
+        self.can_cache_kv_outside_max_seq_len = no_abs_pos_emb
+
+    def forward(
+        self,
+        x: Dict[str, Tensor],
+        return_embeddings = False,
+        return_logits_and_embeddings = False,
+        return_intermediates = False,
+        mask = None,
+        return_mems = False,
+        return_attn = False,
+        mems = None,
+        mem_masks = None,
+        pos = None,
+        prepend_embeds = None,
+        prepend_mask = None,
+        sum_embeds = None,
+        return_attn_z_loss = False,
+        attn_z_loss_weight = 1e-4,
+        seq_start_pos = None,
+        cache: LayerIntermediates | None = None,
+        **kwargs
+    ):
+        assert not is_empty(x)
+        first_input = list(x.values())[0]
+
+        b, n, device, num_mems, has_memory_tokens, emb_frac_gradient = *first_input.shape, first_input.device, self.num_memory_tokens, self.num_memory_tokens > 0, self.emb_frac_gradient
+
+        return_hiddens = return_mems | return_attn | return_intermediates | return_attn_z_loss
+        return_embeddings = return_embeddings | (not exists(self.to_logits))
+
+        # token embedding
+
+        assert len(x) == len(self.embeds)
+
+        token_emb = 0.
+
+        for name, embed_id in x.items():
+            embed_key = f'{name}_embed'
+
+            assert embed_key in self.embeds
+            embed = self.embeds[embed_key](embed_id)
+
+            token_emb = token_emb + embed
+
+        # absolute positional embedding
+
+        external_pos_emb = exists(pos) and pos.dtype != torch.long
+        pos_emb = self.pos_emb(first_input, pos = pos, seq_start_pos = seq_start_pos) if not external_pos_emb else pos
+
+        token_emb = token_emb + pos_emb
+
+        # for summing embeddings passed externally - needs this for self-conditioning in non-autoregressive training
+
+        if exists(sum_embeds):
+            token_emb = token_emb + sum_embeds
+
+        # set back to `x`
+
+        x = token_emb
+
+        # post embedding norm, purportedly leads to greater stabilization
+
+        x = self.post_emb_norm(x)
+
+        # whether to append embeds, as in PaLI, for image embeddings
+
+        if exists(prepend_embeds):
+            prepend_seq, prepend_dim = prepend_embeds.shape[1:]
+            assert prepend_dim == x.shape[-1], 'prepended embeddings need to have same dimensions as text model dimensions'
+
+            x = torch.cat((prepend_embeds, x), dim = -2)
+
+            if exists(prepend_mask) or exists(mask):
+                mask = default(mask, lambda: torch.ones((b, n), device = device, dtype = torch.bool))
+                prepend_mask = default(prepend_mask, lambda: torch.ones((b, prepend_seq), device = device, dtype = torch.bool))
+
+                mask = torch.cat((prepend_mask, mask), dim = -1)
+
+        # whether to reduce the gradient going to the embedding, from cogview paper, corroborated by GLM-130B model
+
+        if emb_frac_gradient < 1:
+            assert emb_frac_gradient > 0
+            x = x * emb_frac_gradient + x.detach() * (1 - emb_frac_gradient)
+
+        # embedding dropout
+
+        x = self.emb_dropout(x)
+
+        x = self.project_emb(x)
+
+        if has_memory_tokens:
+            mem_every = self.memory_tokens_interspersed_every
+
+            if exists(mem_every):
+                assert mem_every > 0
+                assert isinstance(self.attn_layers, Decoder), 'only for decoder'
+                next_seq_len = math.ceil(n / mem_every) * mem_every
+
+                x = pad_at_dim(x, (0, next_seq_len - n), dim = -2, value = 0.)
+                x = rearrange(x, 'b (n m) d -> (b n) m d', m = mem_every)
+
+            mem = repeat(self.memory_tokens, 'n d -> b n d', b = x.shape[0])
+            x, mem_packed_shape = pack((mem, x), 'b * d')
+
+            # auto-handle masking after appending memory tokens
+            if not exists(mem_every) and exists(mask):
+                mask = pad_at_dim(mask, (num_mems, 0), dim = -1, value = True)
+
+            if exists(mem_every):
+                x = rearrange(x, '(b n) m d -> b (n m) d', b = b)
+
+        if self.shift_mem_down and exists(mems):
+            mems_l, mems_r = mems[:self.shift_mem_down], mems[self.shift_mem_down:]
+            mems = [*mems_r, *mems_l]
+
+        x, intermediates = self.attn_layers(x, mask = mask, mems = mems, mem_masks = mem_masks, cache = cache, return_hiddens = True, seq_start_pos = seq_start_pos, **kwargs)
+
+        # handle memories post-attention
+
+        if has_memory_tokens:
+            if exists(mem_every):
+                x = rearrange(x, 'b (n m) d -> (b n) m d', m = (mem_every + num_mems))
+
+            mem, x = unpack(x, mem_packed_shape, 'b * d')
+
+            intermediates.memory_tokens = mem
+
+            if exists(mem_every):
+                x = rearrange(x, '(b n) m d -> b (n m) d', b = b)
+
+            x = x[:, :n]
+
+        # projecting to logits
+
+        if not return_embeddings:
+            if self.has_multiple_heads:
+                logits = tuple(fn(x) for fn in self.to_logits)
+            else:
+                logits = self.to_logits(x)
+
+        # different returns
+
+        if return_logits_and_embeddings:
+            out = (logits, x)
+        elif return_embeddings:
+            out = x
+        else:
+            out = logits
+
+        # aux loss
+
+        if return_attn_z_loss:
+            pre_softmax_attns = [t.pre_softmax_attn for t in intermediates.attn_intermediates]
+            intermediates.attn_z_loss = calc_z_loss(pre_softmax_attns, weight = attn_z_loss_weight)
+            return_intermediates = True
+
+        if return_mems:
+            hiddens = intermediates.hiddens
+            new_mems = [torch.cat(pair, dim = -2) for pair in zip(mems, hiddens)] if exists(mems) else hiddens
+            new_mems = [t[..., -self.max_mem_len:, :].detach() for t in new_mems]
+
+            if not return_intermediates:
+                return out, new_mems
+
+            intermediates.mems = new_mems
+
+        if return_intermediates:
+            return out, intermediates
+
+        if return_attn:
+            attn_maps = [t.post_softmax_attn for t in intermediates.attn_intermediates]
+            return out, attn_maps
+
+        return out
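
The new module introduces MultiInputTransformerWrapper, which keeps one embedding table per named input stream, sums the looked-up embeddings together with the positional embedding, and runs the result through the shared attn_layers. Below is a minimal usage sketch, assuming x-transformers 1.32.5 is installed; the stream names (note, pitch, tone) and all sizes are illustrative, not part of the package.

import torch
from x_transformers import MultiInputTransformerWrapper, Decoder

model = MultiInputTransformerWrapper(
    num_tokens = dict(          # one embedding table per named input stream
        note = 20000,
        pitch = 32,
        tone = 16
    ),
    max_seq_len = 1024,
    return_only_embed = True,   # skip the logits head; forward returns embeddings
    attn_layers = Decoder(
        dim = 128,
        depth = 6,
        heads = 8
    )
)

x = dict(                       # keys must match num_tokens; all tensors share shape (batch, seq)
    note = torch.randint(0, 20000, (2, 1024)),
    pitch = torch.randint(0, 32, (2, 1024)),
    tone = torch.randint(0, 16, (2, 1024))
)

embed = model(x)                # (2, 1024, 128)

return_only_embed = True is used here because logits_dim otherwise falls back to the num_tokens dict; pass an integer logits_dim explicitly if a projection to logits is wanted.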

x_transformers/x_transformers.py

@@ -51,6 +51,9 @@ def cast_tuple(val, depth):
 def divisible_by(num, den):
     return (num % den) == 0
 
+def is_empty(x):
+    return len(x) == 0
+
 def maybe(fn):
     @wraps(fn)
     def inner(x, *args, **kwargs):
@@ -1899,6 +1902,7 @@ class TransformerWrapper(Module):
         memory_tokens_interspersed_every = None,
         tie_embedding = False,
         logits_dim = None,
+        return_only_embed = False,
         num_output_heads = 1,
         use_abs_pos_emb = True,
         scaled_sinu_pos_emb = False,
@@ -1948,13 +1952,17 @@ class TransformerWrapper(Module):
 
         self.init_()
 
+        assert num_output_heads > 0
+
         # output head, usually to logits of num_tokens
 
         logits_dim = default(logits_dim, num_tokens)
 
         self.has_multiple_heads = False
 
-        if tie_embedding:
+        if return_only_embed:
+            self.to_logits = None
+        elif tie_embedding:
             self.to_logits = lambda t: t @ self.token_emb.emb.weight.t()
         elif num_output_heads > 1:
             self.has_multiple_heads = True
@@ -2008,7 +2016,9 @@ class TransformerWrapper(Module):
         **kwargs
     ):
         b, n, device, num_mems, has_memory_tokens, emb_frac_gradient = x.shape[0], x.shape[1], x.device, self.num_memory_tokens, self.num_memory_tokens > 0, self.emb_frac_gradient
+
        return_hiddens = return_mems | return_attn | return_intermediates | return_attn_z_loss
+        return_embeddings = return_embeddings | (not exists(self.to_logits))
 
         # absolute positional embedding
 
@@ -2018,6 +2028,8 @@ class TransformerWrapper(Module):
 
         # add additional embeddings
 
+        assert not (exists(self.embeds) ^ (len(embed_ids) > 0)), '`embed_num_tokens` must be defined on `TransformerWrapper`'
+
         if exists(self.embeds):
             assert len(embed_ids) == len(self.embeds)
 
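TransformerWrapper gains the same return_only_embed flag: when set, no output projection is constructed (to_logits is None) and, via the new return_embeddings override in forward, the wrapper always returns the final hidden states. A brief sketch under the same assumptions (sizes are illustrative):

import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 512,
    return_only_embed = True,    # to_logits is None, so embeddings come back directly
    attn_layers = Decoder(
        dim = 256,
        depth = 4,
        heads = 8
    )
)

tokens = torch.randint(0, 20000, (1, 512))
embeddings = model(tokens)       # (1, 512, 256)

The added assertion in forward also makes it an explicit error to pass embed_ids without having configured embed_num_tokens on the wrapper, and vice versa.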

METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: x-transformers
-Version: 1.32.2
+Version: 1.32.5
 Summary: X-Transformers - Pytorch
 Home-page: https://github.com/lucidrains/x-transformers
 Author: Phil Wang

RECORD

@@ -1,14 +1,15 @@
-x_transformers/__init__.py,sha256=5ms39Df8osTUHQ-XTCgP4vSUA4UiNpim9VXJtrLrIvQ,724
+x_transformers/__init__.py,sha256=-MkQrSc37cTVDX7AOykxunYnqVtFlQ7lb0Cse5dsGWU,793
 x_transformers/attend.py,sha256=MI-m91wumBFqFqr_KK9MLgsLk_vPeaVbFMyDr_mWdmY,11349
 x_transformers/autoregressive_wrapper.py,sha256=uX8Mb0zLsQrZECt_9UGt35g7tC05Rk3nPqO6xp2FFCc,9619
 x_transformers/continuous.py,sha256=WO52n9lFAXv5-SGadi2cApGF8dkouN8QSTEOuC7erj8,6180
 x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
+x_transformers/multi_input.py,sha256=S4hJffAhN7Ubbd92_9YmZ0VOzdu1u-1e2IgWZxYz9BU,9758
 x_transformers/nonautoregressive_wrapper.py,sha256=ys_p8obc7lTeeodCqvkRKxOXQ1C9T3j5Jwr-JbVgnXk,10432
-x_transformers/x_transformers.py,sha256=1QG7zUe89h1R5VDMoKEAkvdRRDkzQ7h6npkqblxxR6g,76312
+x_transformers/x_transformers.py,sha256=5DHbYgx0RPg9QHvfBs2qHWrtn4Jji-q0d1MRBbcRPR8,76696
 x_transformers/xl_autoregressive_wrapper.py,sha256=DCx4n0_c1tFai4nOqaWVnqx2p9eutsZsDMiMP1ckxNU,4117
 x_transformers/xval.py,sha256=QE1ltYZTR_eGgIHPP2BrMWVWVLqMW-OpDZh87BSmQEg,8563
-x_transformers-1.32.2.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-1.32.2.dist-info/METADATA,sha256=U0Kh4e7UiL-0hLDZb0P3McdvTnzTeFyVwtoXFffzQ-M,661
-x_transformers-1.32.2.dist-info/WHEEL,sha256=R0nc6qTxuoLk7ShA2_Y-UWkN8ZdfDBG2B6Eqpz2WXbs,91
-x_transformers-1.32.2.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
-x_transformers-1.32.2.dist-info/RECORD,,
+x_transformers-1.32.5.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.32.5.dist-info/METADATA,sha256=HSujd9BjY5lAIY5QxDqf4JDecEOXrvWQoe2zjpRaGG0,661
+x_transformers-1.32.5.dist-info/WHEEL,sha256=R0nc6qTxuoLk7ShA2_Y-UWkN8ZdfDBG2B6Eqpz2WXbs,91
+x_transformers-1.32.5.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.32.5.dist-info/RECORD,,