x-transformers 1.42.6__py3-none-any.whl → 1.42.8__py3-none-any.whl
- x_transformers/x_transformers.py +34 -13
- {x_transformers-1.42.6.dist-info → x_transformers-1.42.8.dist-info}/METADATA +1 -1
- {x_transformers-1.42.6.dist-info → x_transformers-1.42.8.dist-info}/RECORD +6 -6
- {x_transformers-1.42.6.dist-info → x_transformers-1.42.8.dist-info}/LICENSE +0 -0
- {x_transformers-1.42.6.dist-info → x_transformers-1.42.8.dist-info}/WHEEL +0 -0
- {x_transformers-1.42.6.dist-info → x_transformers-1.42.8.dist-info}/top_level.txt +0 -0
x_transformers/x_transformers.py
CHANGED
@@ -238,6 +238,13 @@ class TokenEmbedding(Module):
         token_emb = self.emb(x.long())
         return l2norm(token_emb) if self.l2norm_embed else token_emb
 
+    def init_(self):
+        if self.l2norm_embed:
+            nn.init.normal_(self.emb.weight, std=1e-5)
+            return
+        nn.init.kaiming_normal_(self.emb.weight)
+
+
 # positional embeddings
 
 class AbsolutePositionalEmbedding(Module):
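The new init_ hook lets TokenEmbedding own its weight initialization (near-zero normal when l2norm_embed is set, kaiming normal otherwise). A minimal usage sketch, not part of the diff, assuming the existing TokenEmbedding(dim, num_tokens, l2norm_embed = ...) constructor:

# usage sketch (not from the diff): TokenEmbedding now initializes itself via init_()
import torch
from x_transformers.x_transformers import TokenEmbedding

emb = TokenEmbedding(dim = 512, num_tokens = 256, l2norm_embed = True)
emb.init_()                          # std 1e-5 normal init because l2norm_embed is set

tokens = torch.randint(0, 256, (1, 16))
print(emb(tokens).shape)             # torch.Size([1, 16, 512])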
@@ -1246,6 +1253,7 @@ class Attention(Module):
         rel_pos = None,
         attn_bias = None,
         rotary_pos_emb = None,
+        pos = None, # for custom alibi positions
         prev_attn = None,
         mem = None,
         mem_mask = None,
@@ -1392,7 +1400,14 @@ class Attention(Module):
 
         if exists(rel_pos):
             assert not exists(attn_bias)
-            attn_bias = rel_pos(i, j)
+
+            if exists(pos):
+                assert isinstance(rel_pos, AlibiPositionalBias), 'only alibi allowed for custom positions at the moment'
+                # allow for custom positions to be passed in
+                attn_bias = rel_pos.forward_custom_pos(pos)
+            else:
+                attn_bias = rel_pos(i, j)
+
             attn_bias = pad_at_dim(attn_bias, (num_mem_kv, 0), value = 0.) # handle memory key / values
 
         # prepare data dependent alibi from forgetting transformers paper, if needed
@@ -1843,6 +1858,7 @@ class AttentionLayers(Module):
         cache_age = 1,
         return_hiddens = False,
         rotary_pos_emb = None,
+        pos = None,
         attn_bias = None,
         condition = None,
         in_attn_cond = None, # https://arxiv.org/abs/2105.04090
@@ -1906,7 +1922,9 @@ class AttentionLayers(Module):
             maybe_mem = mems[0] # todo - handle edge case where different layers get different memory lengths. don't think this will ever come up but who knows
             mem_len = maybe_mem.shape[1] if exists(maybe_mem) else 0
 
-            pos = torch.arange(x.shape[1] + mem_len, device = x.device) - mem_len
+            if not exists(pos):
+                pos = torch.arange(x.shape[1] + mem_len, device = x.device) - mem_len
+
             rotary_pos_emb = self.rotary_pos_emb(pos)
 
         # assume cached key / values
@@ -2030,7 +2048,7 @@ class AttentionLayers(Module):
             # forward depending on layer type
 
             if layer_type == 'a':
-                out, inter = block(x, mask = mask, context_mask = self_attn_kv_mask, attn_mask = attn_mask, rel_pos = self.rel_pos, rotary_pos_emb = rotary_pos_emb, prev_attn = prev_attn, cache = next(iter_attn_cache, None), mem = layer_mem, mem_mask = layer_mem_mask, attn_bias = attn_bias, value_residual = maybe_self_attn_value_residual, return_intermediates = True)
+                out, inter = block(x, mask = mask, context_mask = self_attn_kv_mask, attn_mask = attn_mask, rel_pos = self.rel_pos, pos = pos, rotary_pos_emb = rotary_pos_emb, prev_attn = prev_attn, cache = next(iter_attn_cache, None), mem = layer_mem, mem_mask = layer_mem_mask, attn_bias = attn_bias, value_residual = maybe_self_attn_value_residual, return_intermediates = True)
             elif layer_type == 'c':
                 out, inter = block(x, context = context, mask = mask, context_mask = context_mask, prev_attn = prev_cross_attn, cache = next(iter_attn_cache, None), value_residual = maybe_cross_attn_value_residual, return_intermediates = True)
             elif layer_type == 'f':
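Taken together, the hunks above thread a new pos argument from AttentionLayers.forward down into each self-attention block, where the alibi bias is computed from the custom positions via forward_custom_pos. A minimal sketch, not part of the diff, of how this could be exercised, assuming alibi relative positions are enabled (the new assert only permits AlibiPositionalBias with custom positions):

# usage sketch (not from the diff) of the new pos kwarg on AttentionLayers
import torch
from x_transformers import Decoder

attn_layers = Decoder(
    dim = 512,
    depth = 2,
    heads = 8,
    alibi_pos_bias = True            # rel_pos becomes AlibiPositionalBias, required for custom pos
)

x = torch.randn(1, 128, 512)         # already-embedded tokens
pos = torch.arange(0, 256, 2)[:128]  # custom, non-contiguous integer positions, one per token

out = attn_layers(x, pos = pos)      # pos is forwarded to every self-attention block
print(out.shape)                     # torch.Size([1, 128, 512])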
@@ -2250,7 +2268,8 @@ class TransformerWrapper(Module):
         token_emb: TokenEmbedding | None = None,
         mixture_of_softmax = False,
         mixture_of_softmax_k = 4,
-        sigsoftmax_logits = False
+        sigsoftmax_logits = False,
+        to_logits: Module | None = None,
     ):
         super().__init__()
 
@@ -2352,11 +2371,12 @@ class TransformerWrapper(Module):
         if return_only_embed:
             self.to_logits = None
         elif tie_embedding:
+            assert isinstance(token_emb, TokenEmbedding), 'can only tie embedding if using `TokenEmbedding`'
             self.to_logits = lambda t: t @ self.token_emb.emb.weight.t()
         elif num_output_heads > 1:
             self.to_logits = ModuleList([LinearNoBias(dim, logits_dim) for _ in range(num_output_heads)])
         else:
-            self.to_logits = LinearNoBias(dim, logits_dim)
+            self.to_logits = LinearNoBias(dim, logits_dim) if not exists(to_logits) else to_logits
 
         # memory tokens (like [cls]) from Memory Transformers paper
 
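With the new to_logits argument, TransformerWrapper accepts a user-supplied projection head in place of the default LinearNoBias(dim, logits_dim). A hedged sketch, not part of the diff:

# usage sketch (not from the diff): overriding the output projection
import torch
from torch import nn
from x_transformers import TransformerWrapper, Decoder

custom_head = nn.Sequential(         # any Module mapping dim -> num_tokens works here
    nn.LayerNorm(512),
    nn.Linear(512, 256, bias = False)
)

model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 1024,
    attn_layers = Decoder(dim = 512, depth = 2, heads = 8),
    to_logits = custom_head          # new in 1.42.8
)

x = torch.randint(0, 256, (1, 64))
print(model(x).shape)                # torch.Size([1, 64, 256])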
@@ -2377,13 +2397,12 @@ class TransformerWrapper(Module):
         self.can_cache_kv_outside_max_seq_len = no_abs_pos_emb
 
     def init_(self):
+        if hasattr(self.token_emb, 'init_'):
+            self.token_emb.init_()
+
         if self.l2norm_embed:
-            nn.init.normal_(self.token_emb.emb.weight, std = 1e-5)
             if not isinstance(self.pos_emb, always):
                 nn.init.normal_(self.pos_emb.emb.weight, std = 1e-5)
-            return
-
-        nn.init.kaiming_normal_(self.token_emb.emb.weight)
 
     def forward(
         self,
@@ -2406,7 +2425,9 @@ class TransformerWrapper(Module):
         attn_z_loss_weight = 1e-4,
         seq_start_pos = None,
         cache: LayerIntermediates | None = None,
-        **kwargs
+        token_emb_kwargs = dict(),
+        to_logits_kwargs = dict(),
+        **kwargs,
     ):
         b, n, device, num_mems, has_memory_tokens, emb_frac_gradient, orig_mask = x.shape[0], x.shape[1], x.device, self.num_memory_tokens, self.num_memory_tokens > 0, self.emb_frac_gradient, mask
 
@@ -2417,7 +2438,7 @@ class TransformerWrapper(Module):
 
         external_pos_emb = exists(pos) and pos.dtype != torch.long
         pos_emb = self.pos_emb(x, pos = pos, seq_start_pos = seq_start_pos) if not external_pos_emb else pos
-        x = self.token_emb(x) + pos_emb
+        x = self.token_emb(x, **token_emb_kwargs) + pos_emb
 
         # add additional embeddings
 
@@ -2572,9 +2593,9 @@ class TransformerWrapper(Module):
 
         if not return_embeddings:
             if self.has_multiple_heads:
-                logits = tuple(fn(x) for fn in self.to_logits)
+                logits = tuple(fn(x, **to_logits_kwargs) for fn in self.to_logits)
             else:
-                logits = self.to_logits(x)
+                logits = self.to_logits(x, **to_logits_kwargs)
 
         # maybe sig softmax
 
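The new token_emb_kwargs and to_logits_kwargs dictionaries are forwarded verbatim to the token embedding and output head, which only matters when those are custom modules whose forward accepts extra keyword arguments. A hypothetical sketch (TemperatureLogits below is invented for illustration and is not part of the library):

# hypothetical sketch: forwarding extra kwargs to a custom output head
import torch
from torch import nn
from x_transformers import TransformerWrapper, Decoder

class TemperatureLogits(nn.Module):
    # invented example head whose forward takes an extra keyword argument
    def __init__(self, dim, num_tokens):
        super().__init__()
        self.proj = nn.Linear(dim, num_tokens, bias = False)

    def forward(self, x, temperature = 1.):
        return self.proj(x) / temperature

model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 1024,
    attn_layers = Decoder(dim = 512, depth = 2, heads = 8),
    to_logits = TemperatureLogits(512, 256)
)

x = torch.randint(0, 256, (1, 32))
logits = model(x, to_logits_kwargs = dict(temperature = 0.7))   # forwarded into to_logits.forward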
{x_transformers-1.42.6.dist-info → x_transformers-1.42.8.dist-info}/RECORD
CHANGED

@@ -6,11 +6,11 @@ x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/neo_mlp.py,sha256=XCNnnop9WLarcxap1kGuYc1x8GHvwkZiDRnXOxSl3Po,3452
 x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
-x_transformers/x_transformers.py,sha256=
+x_transformers/x_transformers.py,sha256=275B_yDHePxUvlLcMNgnCUmZ1qZEkwBrpk6IA8n-pnY,93550
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=7S00kCuab4tWQa-vf-z-XfzADjVj48MoFIr7VSIvttg,8575
-x_transformers-1.42.
-x_transformers-1.42.
-x_transformers-1.42.
-x_transformers-1.42.
-x_transformers-1.42.
+x_transformers-1.42.8.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.42.8.dist-info/METADATA,sha256=1d2BVA6iHKpT4UzbYxw16ijAFGJT-u29zTnYtV6Lp3w,689
+x_transformers-1.42.8.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
+x_transformers-1.42.8.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.42.8.dist-info/RECORD,,
{x_transformers-1.42.6.dist-info → x_transformers-1.42.8.dist-info}/LICENSE
File without changes

{x_transformers-1.42.6.dist-info → x_transformers-1.42.8.dist-info}/WHEEL
File without changes

{x_transformers-1.42.6.dist-info → x_transformers-1.42.8.dist-info}/top_level.txt
File without changes