x-transformers 1.37.3__py3-none-any.whl → 1.37.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- x_transformers/attend.py +33 -3
- x_transformers/autoregressive_wrapper.py +1 -1
- x_transformers/nonautoregressive_wrapper.py +3 -1
- x_transformers/x_transformers.py +3 -1
- x_transformers/xl_autoregressive_wrapper.py +6 -4
- {x_transformers-1.37.3.dist-info → x_transformers-1.37.5.dist-info}/METADATA +1 -1
- x_transformers-1.37.5.dist-info/RECORD +15 -0
- x_transformers-1.37.3.dist-info/RECORD +0 -15
- {x_transformers-1.37.3.dist-info → x_transformers-1.37.5.dist-info}/LICENSE +0 -0
- {x_transformers-1.37.3.dist-info → x_transformers-1.37.5.dist-info}/WHEEL +0 -0
- {x_transformers-1.37.3.dist-info → x_transformers-1.37.5.dist-info}/top_level.txt +0 -0
x_transformers/attend.py
CHANGED
@@ -13,7 +13,7 @@ from functools import wraps
 from packaging import version
 from dataclasses import dataclass

-from einops import rearrange, repeat
+from einops import rearrange, repeat, pack, unpack

 # constants

@@ -39,9 +39,16 @@ def default(val, d):
 def compact(arr):
     return [*filter(exists, arr)]

-
+@torch.jit.script
+def softclamp(t: Tensor, value: float):
     return (t / value).tanh() * value

+def pack_one(t, pattern):
+    return pack([t], pattern)
+
+def unpack_one(t, ps, pattern):
+    return unpack(t, ps, pattern)[0]
+
 def once(fn):
     called = False
     @wraps(fn)
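The softclamp helper above is now compiled with @torch.jit.script, which is why the explicit Tensor and float annotations appear in the new signature. A standalone sketch of the scripted function with illustrative inputs:

    import torch
    from torch import Tensor

    @torch.jit.script
    def softclamp(t: Tensor, value: float):
        # squashes values smoothly into the open interval (-value, value)
        return (t / value).tanh() * value

    x = torch.linspace(-200., 200., steps = 5)
    print(softclamp(x, 50.))   # endpoints land near -50 / 50 instead of -200 / 200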
@@ -55,6 +62,18 @@ def once(fn):

 print_once = once(print)

+# alternative distance functions
+
+def qk_l2_distance(q, k):
+    if k.ndim == 3:
+        k = repeat(k, 'b j d -> b h j d', h = q.shape[1])
+
+    q, packed_shape = pack_one(q, '* i d')
+    k, _ = pack_one(k, '* j d')
+
+    distance = torch.cdist(q, k)
+    return unpack_one(distance, packed_shape, '* i j')
+
 # functions for creating causal mask
 # need a special one for onnx cpu (no support for .triu)

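The new qk_l2_distance helper uses pack/unpack to fold the batch and head dimensions together so torch.cdist can operate on plain (batch, queries, dim) tensors, then restores the original shape. A self-contained sanity check of what it returns, written against torch and einops only (shapes are illustrative; the small helpers are restated from the hunk above):

    import torch
    from einops import pack, unpack

    def pack_one(t, pattern):
        return pack([t], pattern)

    def unpack_one(t, ps, pattern):
        return unpack(t, ps, pattern)[0]

    q = torch.randn(2, 8, 16, 64)   # (batch, heads, queries, dim head)
    k = torch.randn(2, 8, 32, 64)   # (batch, heads, keys, dim head)

    qp, packed_shape = pack_one(q, '* i d')
    kp, _ = pack_one(k, '* j d')
    dist = unpack_one(torch.cdist(qp, kp), packed_shape, '* i j')

    # the same quantity computed directly: pairwise euclidean distance per (batch, head)
    expected = ((q.unsqueeze(-2) - k.unsqueeze(-3)) ** 2).sum(dim = -1).sqrt()
    assert torch.allclose(dist, expected, atol = 1e-3)   # cdist's matmul path differs slightly
    print(dist.shape)   # torch.Size([2, 8, 16, 32])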
@@ -80,6 +99,7 @@ class Attend(Module):
         sparse_topk = None,
         scale = None,
         qk_norm = False,
+        l2_distance = False,
         flash = False,
         softclamp_logits = False,
         logit_softclamp_value = 50.,
@@ -123,6 +143,11 @@ class Attend(Module):
         assert not (flash and sigsoftmax), 'sigsoftmax not available for flash attention'
         self.sigsoftmax = sigsoftmax

+        # l2 distance attention
+
+        assert not (flash and l2_distance), 'l2 distance attention does not work with flash attention just yet'
+        self.l2_distance = l2_distance
+
         # add a key / value token composed of zeros
         # in case this helps controlling outliers, proposed by https://www.evanmiller.org/attention-is-off-by-one.html

@@ -325,7 +350,12 @@ class Attend(Module):

         kv_einsum_eq = 'b j d' if k.ndim == 3 else 'b h j d'

-
+        if not self.l2_distance:
+            sim = einsum(f'b h i d, {kv_einsum_eq} -> b h i j', q, k)
+        else:
+            sim = -qk_l2_distance(q, k)
+
+        sim = sim * scale

         if exists(prev_attn):
             sim = sim + prev_attn
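Taken together, the attend.py changes add an opt-in l2_distance mode: the similarity logits become the negated pairwise euclidean distance between queries and keys instead of a dot product, still multiplied by the usual scale, and the new assert keeps the option off the flash path. A hedged usage sketch against the Attend module; the shapes and the (out, intermediates) return convention are assumptions for illustration:

    import torch
    from x_transformers.attend import Attend

    attend = Attend(l2_distance = True)   # flash stays False, per the new assert

    q = torch.randn(1, 8, 32, 64)   # (batch, heads, query length, dim head)
    k = torch.randn(1, 8, 48, 64)
    v = torch.randn(1, 8, 48, 64)

    out, intermediates = attend(q, k, v)
    print(out.shape)                # expected: torch.Size([1, 8, 32, 64])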
x_transformers/autoregressive_wrapper.py
CHANGED
@@ -317,7 +317,7 @@ class AutoregressiveWrapper(Module):
             **kwargs
         )

-        loss_fn = F.cross_entropy if not self.net.
+        loss_fn = F.cross_entropy if not self.net.output_is_log_prob else F.nll_loss

         loss = loss_fn(
             rearrange(logits, 'b n c -> b c n'),
x_transformers/nonautoregressive_wrapper.py
CHANGED
@@ -309,9 +309,11 @@ class NonAutoregressiveWrapper(nn.Module):
         with context():
             logits = self.net(masked, **kwargs)

+        loss_fn = F.cross_entropy if not self.net.output_is_log_prob else F.nll_loss
+
         # cross entropy loss

-        loss =
+        loss = loss_fn(
             logits[mask],
             orig_seq[mask]
         )
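Both wrappers now choose their loss from the network's output_is_log_prob flag (set by TransformerWrapper later in this diff): F.cross_entropy expects raw logits and applies log_softmax internally, so a network that already emits log probabilities (e.g. with mixture of softmax) needs F.nll_loss to avoid normalizing twice. A minimal check of that equivalence with made-up shapes:

    import torch
    import torch.nn.functional as F

    logits = torch.randn(2, 10, 5)          # (batch, seq, num tokens)
    labels = torch.randint(0, 5, (2, 10))

    # raw logits -> cross_entropy
    ce = F.cross_entropy(logits.transpose(1, 2), labels)

    # log probabilities -> nll_loss yields the same value
    nll = F.nll_loss(logits.log_softmax(dim = -1).transpose(1, 2), labels)

    assert torch.allclose(ce, nll, atol = 1e-6)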
x_transformers/x_transformers.py
CHANGED
@@ -923,6 +923,7 @@ class Attention(Module):
         qk_norm_groups = 1,
         qk_norm_scale = 10,
         qk_norm_dim_scale = False,
+        l2_distance = False,
         one_kv_head = False,
         kv_heads = None,
         shared_kv = False,
@@ -1037,6 +1038,7 @@ class Attention(Module):
             sparse_topk = sparse_topk,
             qk_norm = qk_norm,
             scale = qk_norm_scale if qk_norm else self.scale,
+            l2_distance = l2_distance,
             add_zero_kv = add_zero_kv,
             flash = flash,
             softclamp_logits = softclamp_logits,
@@ -2078,7 +2080,7 @@ class TransformerWrapper(Module):

         # output type

-        self.
+        self.output_is_log_prob = mixture_of_softmax

         self.to_mixture = None
         self.combine_mixture = None
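In x_transformers.py the new option is only threaded through: Attention gains an l2_distance kwarg and forwards it to Attend, and TransformerWrapper records output_is_log_prob (true when mixture_of_softmax is enabled) for the loss selection shown above. At the user-facing level the flag would normally be reached through the attn_-prefixed kwargs that x-transformers routes to its attention layers; the exact spelling below follows that convention but should be treated as an assumption rather than documented API:

    import torch
    from x_transformers import TransformerWrapper, Decoder

    model = TransformerWrapper(
        num_tokens = 256,
        max_seq_len = 1024,
        attn_layers = Decoder(
            dim = 512,
            depth = 6,
            heads = 8,
            attn_l2_distance = True   # assumed routing: attn_* -> Attention(l2_distance = ...) -> Attend
        )
    )

    tokens = torch.randint(0, 256, (1, 128))
    logits = model(tokens)           # (1, 128, 256)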
x_transformers/xl_autoregressive_wrapper.py
CHANGED
@@ -40,7 +40,7 @@ class XLAutoregressiveWrapper(nn.Module):
         eos_token = None,
         temperature = 1.,
         filter_logits_fn = top_k,
-
+        filter_kwargs: dict = dict(),
         mems = None,
         **kwargs
     ):
@@ -88,7 +88,7 @@ class XLAutoregressiveWrapper(nn.Module):
             mems = cache.mems

             logits = logits[:, -1]
-            filtered_logits = filter_logits_fn(logits,
+            filtered_logits = filter_logits_fn(logits, **filter_kwargs)
             probs = F.softmax(filtered_logits / temperature, dim=-1)

             sample = torch.multinomial(probs, 1)
@@ -131,7 +131,9 @@ class XLAutoregressiveWrapper(nn.Module):

         split_x = x.split(max_seq_len, dim = -1)
         split_labels = labels.split(max_seq_len, dim = -1)
-        loss_weights = tuple(
+        loss_weights = tuple((t.shape[-1] / seq_len) for t in split_x)
+
+        loss_fn = F.cross_entropy if not self.net.output_is_log_prob else F.nll_loss

         # go through each chunk and derive weighted losses

@@ -146,7 +148,7 @@ class XLAutoregressiveWrapper(nn.Module):
                 **kwargs
             )

-            loss =
+            loss = loss_fn(
                 rearrange(logits, 'b n c -> b c n'),
                 chunk_labels,
                 ignore_index = ignore_index
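XLAutoregressiveWrapper.generate now accepts a filter_kwargs dict that is splatted into filter_logits_fn, replacing the previous fixed threshold argument, and its training path reuses the same output_is_log_prob loss selection. A hedged sketch of calling generate with an explicit filter configuration; the model sizes are arbitrary and the thres parameter assumes the top_p helper from x_transformers.autoregressive_wrapper:

    import torch
    from x_transformers import TransformerWrapper, Decoder, XLAutoregressiveWrapper
    from x_transformers.autoregressive_wrapper import top_p

    model = TransformerWrapper(
        num_tokens = 256,
        max_seq_len = 128,
        max_mem_len = 128,
        attn_layers = Decoder(dim = 256, depth = 2, heads = 4)
    )

    xl = XLAutoregressiveWrapper(model)

    prompt = torch.randint(0, 256, (1, 4))
    generated = xl.generate(
        prompt,
        seq_len = 64,
        filter_logits_fn = top_p,
        filter_kwargs = dict(thres = 0.9)   # forwarded as filter_logits_fn(logits, **filter_kwargs)
    )
    print(generated.shape)   # expected: torch.Size([1, 64])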
x_transformers-1.37.5.dist-info/RECORD
ADDED
@@ -0,0 +1,15 @@
+x_transformers/__init__.py,sha256=-MkQrSc37cTVDX7AOykxunYnqVtFlQ7lb0Cse5dsGWU,793
+x_transformers/attend.py,sha256=4RnX1yhWZIf8holucqnYXTIP7U1m40UpP58RZNT_2sM,13128
+x_transformers/autoregressive_wrapper.py,sha256=DOJJCMMDOqDYKWy_IaG5IyKsXD3AW6amzfUgdAADOLY,10500
+x_transformers/continuous.py,sha256=cIVEdhfei258__ziV7kQBrJMxCel54bExBTDrO9rfCI,6450
+x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
+x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
+x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
+x_transformers/x_transformers.py,sha256=-2fj6QcDSfMI5lJA_fzOW2mdzdS1C1LD6jMBtGQY48E,83752
+x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
+x_transformers/xval.py,sha256=QE1ltYZTR_eGgIHPP2BrMWVWVLqMW-OpDZh87BSmQEg,8563
+x_transformers-1.37.5.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.37.5.dist-info/METADATA,sha256=zHUhvP1bQjFbMtxnVO9iDESgXpGOQxuBCsm4b6K1w44,661
+x_transformers-1.37.5.dist-info/WHEEL,sha256=GV9aMThwP_4oNCtvEC2ec3qUYutgWeAzklro_0m4WJQ,91
+x_transformers-1.37.5.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.37.5.dist-info/RECORD,,
x_transformers-1.37.3.dist-info/RECORD
DELETED
@@ -1,15 +0,0 @@
-x_transformers/__init__.py,sha256=-MkQrSc37cTVDX7AOykxunYnqVtFlQ7lb0Cse5dsGWU,793
-x_transformers/attend.py,sha256=mV7duZ7ON2puS3-k4ctBifb2rq-jTJqrMbof7tI5jR4,12326
-x_transformers/autoregressive_wrapper.py,sha256=2FN4ZobFcdDGDGWEnUof_geb16dRGSJycZGwG899Pa4,10493
-x_transformers/continuous.py,sha256=cIVEdhfei258__ziV7kQBrJMxCel54bExBTDrO9rfCI,6450
-x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
-x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
-x_transformers/nonautoregressive_wrapper.py,sha256=ys_p8obc7lTeeodCqvkRKxOXQ1C9T3j5Jwr-JbVgnXk,10432
-x_transformers/x_transformers.py,sha256=gOJBZzOJMu5RkIsxw9TZtde4Sx--D18yX8LjrYIsPbE,83677
-x_transformers/xl_autoregressive_wrapper.py,sha256=DCx4n0_c1tFai4nOqaWVnqx2p9eutsZsDMiMP1ckxNU,4117
-x_transformers/xval.py,sha256=QE1ltYZTR_eGgIHPP2BrMWVWVLqMW-OpDZh87BSmQEg,8563
-x_transformers-1.37.3.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-1.37.3.dist-info/METADATA,sha256=SIGTCQMrLkyq_aksJAst0iXw9VfFT6QWlGvtUElbTMg,661
-x_transformers-1.37.3.dist-info/WHEEL,sha256=GV9aMThwP_4oNCtvEC2ec3qUYutgWeAzklro_0m4WJQ,91
-x_transformers-1.37.3.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
-x_transformers-1.37.3.dist-info/RECORD,,
{x_transformers-1.37.3.dist-info → x_transformers-1.37.5.dist-info}/LICENSE
File without changes
{x_transformers-1.37.3.dist-info → x_transformers-1.37.5.dist-info}/WHEEL
File without changes
{x_transformers-1.37.3.dist-info → x_transformers-1.37.5.dist-info}/top_level.txt
File without changes