x-transformers 1.32.4 → 1.32.6 (py3-none-any.whl)

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
--- x_transformers/__init__.py
+++ x_transformers/__init__.py
@@ -20,6 +20,8 @@ from x_transformers.continuous import (
     ContinuousAutoregressiveWrapper
 )
 
+from x_transformers.multi_input import MultiInputTransformerWrapper
+
 from x_transformers.xval import (
     XValTransformerWrapper,
     XValAutoregressiveWrapper
--- /dev/null
+++ x_transformers/multi_input.py
@@ -0,0 +1,294 @@
+from __future__ import annotations
+
+import torch
+from torch import nn, Tensor
+from torch.nn import Module, ModuleDict
+import torch.nn.functional as F
+
+from typing import Dict
+
+from einops import pack, repeat, unpack
+
+from x_transformers.x_transformers import (
+    AttentionLayers,
+    ScaledSinusoidalEmbedding,
+    AbsolutePositionalEmbedding,
+    LayerIntermediates,
+    LayerNorm,
+    always,
+    pad_at_dim,
+    is_empty
+)
+
+# helper functions
+
+def exists(val):
+    return val is not None
+
+def default(val, d):
+    if exists(val):
+        return val
+    return d() if callable(d) else d
+
+
+class MultiInputTransformerWrapper(Module):
+    def __init__(
+        self,
+        *,
+        num_tokens: Dict[str, int] = dict(),
+        max_seq_len,
+        attn_layers: AttentionLayers,
+        emb_dim = None,
+        max_mem_len = 0,
+        shift_mem_down = 0,
+        emb_dropout = 0.,
+        post_emb_norm = False,
+        num_memory_tokens = None,
+        memory_tokens_interspersed_every = None,
+        tie_embedding = False,
+        logits_dim = None,
+        return_only_embed = False,
+        num_output_heads = 1,
+        use_abs_pos_emb = True,
+        scaled_sinu_pos_emb = False,
+        emb_frac_gradient = 1., # GLM-130B and Cogview successfully used this, set at 0.1
+        attn_z_loss_weight = 1e-4,
+    ):
+        super().__init__()
+
+        dim = attn_layers.dim
+        emb_dim = default(emb_dim, dim)
+        self.emb_dim = emb_dim
+
+        self.max_seq_len = max_seq_len
+        self.max_mem_len = max_mem_len
+        self.shift_mem_down = shift_mem_down
+
+        no_abs_pos_emb = max_seq_len == 0 or not (use_abs_pos_emb and not attn_layers.disable_abs_pos_emb)
+
+        if no_abs_pos_emb:
+            self.pos_emb = always(0)
+        elif scaled_sinu_pos_emb:
+            self.pos_emb = ScaledSinusoidalEmbedding(emb_dim)
+        else:
+            self.pos_emb = AbsolutePositionalEmbedding(emb_dim, max_seq_len)
+
+        # additional embeddings - say type embedding from BERT
+
+        self.embeds = ModuleDict({f'{name}_embed': nn.Embedding(one_num_tokens, emb_dim) for name, one_num_tokens in num_tokens.items()})
+
+        # fraction of the gradient that should go to the embedding, https://arxiv.org/abs/2105.13290
+
+        self.emb_frac_gradient = emb_frac_gradient
+
+        self.post_emb_norm = LayerNorm(emb_dim) if post_emb_norm else nn.Identity()
+        self.emb_dropout = nn.Dropout(emb_dropout)
+
+        self.project_emb = nn.Linear(emb_dim, dim) if emb_dim != dim else nn.Identity()
+        self.attn_layers = attn_layers
+
+        assert num_output_heads > 0
+
+        # output head, usually to logits of num_tokens
+
+        logits_dim = default(logits_dim, num_tokens)
+
+        self.has_multiple_heads = False
+
+        if return_only_embed:
+            self.to_logits = None
+        elif tie_embedding:
+            self.to_logits = lambda t: t @ self.token_emb.emb.weight.t()
+        elif num_output_heads > 1:
+            self.has_multiple_heads = True
+            self.to_logits = ModuleList([nn.Linear(dim, logits_dim, bias = False) for _ in range(num_output_heads)])
+        else:
+            self.to_logits = nn.Linear(dim, logits_dim, bias = False)
+
+        # memory tokens (like [cls]) from Memory Transformers paper
+
+        num_memory_tokens = default(num_memory_tokens, 0)
+        self.num_memory_tokens = num_memory_tokens
+        if num_memory_tokens > 0:
+            self.memory_tokens = nn.Parameter(torch.randn(num_memory_tokens, dim))
+
+        self.memory_tokens_interspersed_every = memory_tokens_interspersed_every
+
+        # whether can do cached kv decoding
+
+        self.can_cache_kv = self.num_memory_tokens == 0
+        self.can_cache_kv_outside_max_seq_len = no_abs_pos_emb
+
+    def forward(
+        self,
+        x: Dict[str, Tensor],
+        return_embeddings = False,
+        return_logits_and_embeddings = False,
+        return_intermediates = False,
+        mask = None,
+        return_mems = False,
+        return_attn = False,
+        mems = None,
+        mem_masks = None,
+        pos = None,
+        prepend_embeds = None,
+        prepend_mask = None,
+        sum_embeds = None,
+        return_attn_z_loss = False,
+        attn_z_loss_weight = 1e-4,
+        seq_start_pos = None,
+        cache: LayerIntermediates | None = None,
+        **kwargs
+    ):
+        assert not is_empty(x)
+        first_input = list(x.values())[0]
+
+        b, n, device, num_mems, has_memory_tokens, emb_frac_gradient = *first_input.shape, first_input.device, self.num_memory_tokens, self.num_memory_tokens > 0, self.emb_frac_gradient
+
+        return_hiddens = return_mems | return_attn | return_intermediates | return_attn_z_loss
+        return_embeddings = return_embeddings | (not exists(self.to_logits))
+
+        # token embedding
+
+        assert len(x) == len(self.embeds)
+
+        token_emb = 0.
+
+        for name, embed_id in x.items():
+            embed_key = f'{name}_embed'
+
+            assert embed_key in self.embeds
+            embed = self.embeds[embed_key](embed_id)
+
+            token_emb = token_emb + embed
+
+        # absolute positional embedding
+
+        external_pos_emb = exists(pos) and pos.dtype != torch.long
+        pos_emb = self.pos_emb(first_input, pos = pos, seq_start_pos = seq_start_pos) if not external_pos_emb else pos
+
+        token_emb = token_emb + pos_emb
+
+        # for summing embeddings passed externally - needs this for self-conditioning in non-autoregressive training
+
+        if exists(sum_embeds):
+            token_emb = token_emb + sum_embeds
+
+        # set back to `x`
+
+        x = token_emb
+
+        # post embedding norm, purportedly leads to greater stabilization
+
+        x = self.post_emb_norm(x)
+
+        # whether to append embeds, as in PaLI, for image embeddings
+
+        if exists(prepend_embeds):
+            prepend_seq, prepend_dim = prepend_embeds.shape[1:]
+            assert prepend_dim == x.shape[-1], 'prepended embeddings need to have same dimensions as text model dimensions'
+
+            x = torch.cat((prepend_embeds, x), dim = -2)
+
+            if exists(prepend_mask) or exists(mask):
+                mask = default(mask, lambda: torch.ones((b, n), device = device, dtype = torch.bool))
+                prepend_mask = default(prepend_mask, lambda: torch.ones((b, prepend_seq), device = device, dtype = torch.bool))
+
+                mask = torch.cat((prepend_mask, mask), dim = -1)
+
+        # whether to reduce the gradient going to the embedding, from cogview paper, corroborated by GLM-130B model
+
+        if emb_frac_gradient < 1:
+            assert emb_frac_gradient > 0
+            x = x * emb_frac_gradient + x.detach() * (1 - emb_frac_gradient)
+
+        # embedding dropout
+
+        x = self.emb_dropout(x)
+
+        x = self.project_emb(x)
+
+        if has_memory_tokens:
+            mem_every = self.memory_tokens_interspersed_every
+
+            if exists(mem_every):
+                assert mem_every > 0
+                assert isinstance(self.attn_layers, Decoder), 'only for decoder'
+                next_seq_len = math.ceil(n / mem_every) * mem_every
+
+                x = pad_at_dim(x, (0, next_seq_len - n), dim = -2, value = 0.)
+                x = rearrange(x, 'b (n m) d -> (b n) m d', m = mem_every)
+
+            mem = repeat(self.memory_tokens, 'n d -> b n d', b = x.shape[0])
+            x, mem_packed_shape = pack((mem, x), 'b * d')
+
+            # auto-handle masking after appending memory tokens
+            if not exists(mem_every) and exists(mask):
+                mask = pad_at_dim(mask, (num_mems, 0), dim = -1, value = True)
+
+            if exists(mem_every):
+                x = rearrange(x, '(b n) m d -> b (n m) d', b = b)
+
+        if self.shift_mem_down and exists(mems):
+            mems_l, mems_r = mems[:self.shift_mem_down], mems[self.shift_mem_down:]
+            mems = [*mems_r, *mems_l]
+
+        x, intermediates = self.attn_layers(x, mask = mask, mems = mems, mem_masks = mem_masks, cache = cache, return_hiddens = True, seq_start_pos = seq_start_pos, **kwargs)
+
+        # handle memories post-attention
+
+        if has_memory_tokens:
+            if exists(mem_every):
+                x = rearrange(x, 'b (n m) d -> (b n) m d', m = (mem_every + num_mems))
+
+            mem, x = unpack(x, mem_packed_shape, 'b * d')
+
+            intermediates.memory_tokens = mem
+
+            if exists(mem_every):
+                x = rearrange(x, '(b n) m d -> b (n m) d', b = b)
+
+            x = x[:, :n]
+
+        # projecting to logits
+
+        if not return_embeddings:
+            if self.has_multiple_heads:
+                logits = tuple(fn(x) for fn in self.to_logits)
+            else:
+                logits = self.to_logits(x)
+
+        # different returns
+
+        if return_logits_and_embeddings:
+            out = (logits, x)
+        elif return_embeddings:
+            out = x
+        else:
+            out = logits
+
+        # aux loss
+
+        if return_attn_z_loss:
+            pre_softmax_attns = [t.pre_softmax_attn for t in intermediates.attn_intermediates]
+            intermediates.attn_z_loss = calc_z_loss(pre_softmax_attns, weight = attn_z_loss_weight)
+            return_intermediates = True
+
+        if return_mems:
+            hiddens = intermediates.hiddens
+            new_mems = [torch.cat(pair, dim = -2) for pair in zip(mems, hiddens)] if exists(mems) else hiddens
+            new_mems = [t[..., -self.max_mem_len:, :].detach() for t in new_mems]
+
+            if not return_intermediates:
+                return out, new_mems
+
+            intermediates.mems = new_mems
+
+        if return_intermediates:
+            return out, intermediates
+
+        if return_attn:
+            attn_maps = [t.post_softmax_attn for t in intermediates.attn_intermediates]
+            return out, attn_maps
+
+        return out
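For orientation, a minimal usage sketch of the new wrapper follows. It is not taken from the package: the input names pitch and velocity and all hyperparameters are made up, Decoder is the usual x-transformers attention stack, and logits_dim is passed explicitly because this version defaults it to the num_tokens dict. The sketch also keeps the default configuration, since the interspersed-memory-token path above references rearrange and math, which the new module does not import.

import torch
from x_transformers import Decoder, MultiInputTransformerWrapper

# one embedding table per named input; the per-name embeddings are summed,
# a positional embedding is added, then everything runs through attn_layers
model = MultiInputTransformerWrapper(
    num_tokens = dict(pitch = 128, velocity = 32),  # hypothetical input names and vocab sizes
    max_seq_len = 1024,
    logits_dim = 128,                               # explicit, see note above
    attn_layers = Decoder(dim = 512, depth = 6, heads = 8)
)

ids = dict(
    pitch = torch.randint(0, 128, (2, 1024)),
    velocity = torch.randint(0, 32, (2, 1024))
)

embeds = model(ids, return_embeddings = True)  # (2, 1024, 512)
logits = model(ids)                            # (2, 1024, 128)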
--- x_transformers/x_transformers.py
+++ x_transformers/x_transformers.py
@@ -51,6 +51,9 @@ def cast_tuple(val, depth):
 def divisible_by(num, den):
     return (num % den) == 0
 
+def is_empty(x):
+    return len(x) == 0
+
 def maybe(fn):
     @wraps(fn)
     def inner(x, *args, **kwargs):
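The new helper is imported by multi_input.py to guard its forward pass (assert not is_empty(x)). A quick illustration of the assumed behavior, for any object that supports len:

from x_transformers.x_transformers import is_empty

assert is_empty({}) and is_empty([])   # empty containers
assert not is_empty({'pitch': 128})    # non-empty dict of named vocab sizes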
--- x_transformers-1.32.4.dist-info/METADATA
+++ x_transformers-1.32.6.dist-info/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: x-transformers
-Version: 1.32.4
+Version: 1.32.6
 Summary: X-Transformers - Pytorch
 Home-page: https://github.com/lucidrains/x-transformers
 Author: Phil Wang
--- x_transformers-1.32.4.dist-info/RECORD
+++ x_transformers-1.32.6.dist-info/RECORD
@@ -1,14 +1,15 @@
-x_transformers/__init__.py,sha256=5ms39Df8osTUHQ-XTCgP4vSUA4UiNpim9VXJtrLrIvQ,724
+x_transformers/__init__.py,sha256=-MkQrSc37cTVDX7AOykxunYnqVtFlQ7lb0Cse5dsGWU,793
 x_transformers/attend.py,sha256=MI-m91wumBFqFqr_KK9MLgsLk_vPeaVbFMyDr_mWdmY,11349
 x_transformers/autoregressive_wrapper.py,sha256=uX8Mb0zLsQrZECt_9UGt35g7tC05Rk3nPqO6xp2FFCc,9619
 x_transformers/continuous.py,sha256=WO52n9lFAXv5-SGadi2cApGF8dkouN8QSTEOuC7erj8,6180
 x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
+x_transformers/multi_input.py,sha256=QvYrueLPcfcm0gvoSZYCd7zVgUTi2i0fZkvXowCwx_s,9794
 x_transformers/nonautoregressive_wrapper.py,sha256=ys_p8obc7lTeeodCqvkRKxOXQ1C9T3j5Jwr-JbVgnXk,10432
-x_transformers/x_transformers.py,sha256=o7At5F35Paih1a_-rxqcIP1n4B-ARVJ_ZL2QkOnTnSQ,76655
+x_transformers/x_transformers.py,sha256=5DHbYgx0RPg9QHvfBs2qHWrtn4Jji-q0d1MRBbcRPR8,76696
 x_transformers/xl_autoregressive_wrapper.py,sha256=DCx4n0_c1tFai4nOqaWVnqx2p9eutsZsDMiMP1ckxNU,4117
 x_transformers/xval.py,sha256=QE1ltYZTR_eGgIHPP2BrMWVWVLqMW-OpDZh87BSmQEg,8563
-x_transformers-1.32.4.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-1.32.4.dist-info/METADATA,sha256=rAX2cTBnI50T0Tsa00KgBTz6lzV36ACppE6H2WPLZI4,661
-x_transformers-1.32.4.dist-info/WHEEL,sha256=R0nc6qTxuoLk7ShA2_Y-UWkN8ZdfDBG2B6Eqpz2WXbs,91
-x_transformers-1.32.4.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
-x_transformers-1.32.4.dist-info/RECORD,,
+x_transformers-1.32.6.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.32.6.dist-info/METADATA,sha256=_Uhhkxnq0aIykqJxbQdQOpxcnYJcciV5Z9SwghDiTpQ,661
+x_transformers-1.32.6.dist-info/WHEEL,sha256=R0nc6qTxuoLk7ShA2_Y-UWkN8ZdfDBG2B6Eqpz2WXbs,91
+x_transformers-1.32.6.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.32.6.dist-info/RECORD,,