x-transformers 1.37.8__tar.gz → 1.37.9__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {x_transformers-1.37.8/x_transformers.egg-info → x_transformers-1.37.9}/PKG-INFO +1 -1
- {x_transformers-1.37.8 → x_transformers-1.37.9}/README.md +9 -0
- {x_transformers-1.37.8 → x_transformers-1.37.9}/setup.py +1 -1
- {x_transformers-1.37.8 → x_transformers-1.37.9}/x_transformers/attend.py +21 -2
- {x_transformers-1.37.8 → x_transformers-1.37.9}/x_transformers/x_transformers.py +2 -0
- {x_transformers-1.37.8 → x_transformers-1.37.9/x_transformers.egg-info}/PKG-INFO +1 -1
- {x_transformers-1.37.8 → x_transformers-1.37.9}/LICENSE +0 -0
- {x_transformers-1.37.8 → x_transformers-1.37.9}/setup.cfg +0 -0
- {x_transformers-1.37.8 → x_transformers-1.37.9}/tests/test_x_transformers.py +0 -0
- {x_transformers-1.37.8 → x_transformers-1.37.9}/x_transformers/__init__.py +0 -0
- {x_transformers-1.37.8 → x_transformers-1.37.9}/x_transformers/autoregressive_wrapper.py +0 -0
- {x_transformers-1.37.8 → x_transformers-1.37.9}/x_transformers/continuous.py +0 -0
- {x_transformers-1.37.8 → x_transformers-1.37.9}/x_transformers/dpo.py +0 -0
- {x_transformers-1.37.8 → x_transformers-1.37.9}/x_transformers/multi_input.py +0 -0
- {x_transformers-1.37.8 → x_transformers-1.37.9}/x_transformers/nonautoregressive_wrapper.py +0 -0
- {x_transformers-1.37.8 → x_transformers-1.37.9}/x_transformers/xl_autoregressive_wrapper.py +0 -0
- {x_transformers-1.37.8 → x_transformers-1.37.9}/x_transformers/xval.py +0 -0
- {x_transformers-1.37.8 → x_transformers-1.37.9}/x_transformers.egg-info/SOURCES.txt +0 -0
- {x_transformers-1.37.8 → x_transformers-1.37.9}/x_transformers.egg-info/dependency_links.txt +0 -0
- {x_transformers-1.37.8 → x_transformers-1.37.9}/x_transformers.egg-info/requires.txt +0 -0
- {x_transformers-1.37.8 → x_transformers-1.37.9}/x_transformers.egg-info/top_level.txt +0 -0
{x_transformers-1.37.8 → x_transformers-1.37.9}/README.md
RENAMED
@@ -2252,4 +2252,13 @@ ids_out, num_out, is_number_mask = model.generate(start_ids, start_nums, 17)
 }
 ```
 
+```bibtex
+@inproceedings{Ramapuram2024TheoryAA,
+    title   = {Theory, Analysis, and Best Practices for Sigmoid Self-Attention},
+    author  = {Jason Ramapuram and Federico Danieli and Eeshan Gunesh Dhekane and Floris Weers and Dan Busbridge and Pierre Ablin and Tatiana Likhomanenko and Jagrit Digani and Zijin Gu and Amitis Shidani and Russ Webb},
+    year    = {2024},
+    url     = {https://api.semanticscholar.org/CorpusID:272463580}
+}
+```
+
 *solve intelligence... then use that to solve everything else.* - Demis Hassabis
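The README change above only adds the citation for the sigmoid-attention paper; it does not include a usage snippet. As a hedged sketch (not taken from this diff), and assuming the new `sigmoid` flag is routed to `Attention` through the usual `attn_`-prefixed kwargs, enabling it at the wrapper level would look roughly like this:

```python
import torch
from x_transformers import TransformerWrapper, Decoder

# `attn_sigmoid = True` assumes the standard `attn_` prefix routing of
# Attention kwargs; the exact spelling is an assumption, not shown in this diff
model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        attn_sigmoid = True   # attention weights come from a sigmoid rather than a softmax
    )
)

x = torch.randint(0, 20000, (1, 1024))
logits = model(x)   # (1, 1024, 20000)
```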
{x_transformers-1.37.8 → x_transformers-1.37.9}/x_transformers/attend.py
RENAMED
@@ -36,6 +36,9 @@ def exists(val):
 def default(val, d):
     return val if exists(val) else d
 
+def at_most_one_of(*bools):
+    return sum([*map(int, bools)]) <= 1
+
 def compact(arr):
     return [*filter(exists, arr)]
 
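The new `at_most_one_of` helper simply counts how many flags are truthy; a small standalone check of its behaviour (inputs chosen arbitrarily):

```python
def at_most_one_of(*bools):
    return sum([*map(int, bools)]) <= 1

assert at_most_one_of(False, False)          # no flag set      -> allowed
assert at_most_one_of(True, False, False)    # exactly one flag -> allowed
assert not at_most_one_of(True, True)        # two flags set    -> rejected
```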
@@ -100,6 +103,7 @@ class Attend(Module):
         scale = None,
         qk_norm = False,
         l2_distance = False,
+        sigmoid = False,
         flash = False,
         softclamp_logits = False,
         logit_softclamp_value = 50.,
@@ -116,10 +120,25 @@ class Attend(Module):
         super().__init__()
         self.scale = scale
 
+        # causal related
+
         self.causal = causal
         self.create_causal_mask = onnx_create_causal_mask if onnxable else create_causal_mask
 
-
+        # attention type
+
+        assert not (flash and sigmoid), 'sigmoid attention not available for flash'
+        assert at_most_one_of(sigmoid, l2_distance)
+
+        self.sigmoid = sigmoid
+
+        if not sigmoid:
+            softmax_fn = partial(F.softmax, dim = -1)
+            self.attn_fn = partial(softmax_fn, dtype = torch.float32) if not qk_norm else softmax_fn
+        else:
+            self.attn_fn = F.sigmoid
+
+        # dropouts
 
         self.dropout = dropout
         self.attn_dropout = nn.Dropout(dropout)
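To make the new branch concrete: with `sigmoid = False` the attention weights are a softmax over the key dimension (each query's weights sum to 1), while with `sigmoid = True` every logit is squashed independently, so the rows need not sum to 1. A toy comparison, independent of the library:

```python
import torch
import torch.nn.functional as F
from functools import partial

sim = torch.randn(2, 4, 4)   # toy attention logits: (batch, queries, keys)

softmax_fn = partial(F.softmax, dim = -1)
softmax_attn = softmax_fn(sim, dtype = torch.float32)   # normalized over keys
sigmoid_attn = F.sigmoid(sim)                           # elementwise squashing

assert torch.allclose(softmax_attn.sum(dim = -1), torch.ones(2, 4))
print(sigmoid_attn.sum(dim = -1))   # arbitrary positive values, not necessarily 1
```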
@@ -410,7 +429,7 @@ class Attend(Module):
|
|
410
429
|
if self.sigsoftmax:
|
411
430
|
sim = sim + sim.sigmoid().log()
|
412
431
|
|
413
|
-
attn = self.attn_fn(sim
|
432
|
+
attn = self.attn_fn(sim)
|
414
433
|
attn = attn.type(dtype)
|
415
434
|
|
416
435
|
post_softmax_attn = attn
|
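The call-site change follows directly from the constructor above: the reduction dimension is now bound into `self.attn_fn` via `partial` (and is irrelevant for sigmoid), so the forward pass hands over only the logits. A minimal equivalence check, assuming the removed line passed `dim = -1` explicitly:

```python
import torch
import torch.nn.functional as F
from functools import partial

sim = torch.randn(2, 8, 16, 16)   # toy (batch, heads, queries, keys) logits

old_style = F.softmax(sim, dim = -1)        # dim supplied at the call site
attn_fn   = partial(F.softmax, dim = -1)    # dim bound once, as in Attend.__init__
new_style = attn_fn(sim)

assert torch.allclose(old_style, new_style)
```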
{x_transformers-1.37.8 → x_transformers-1.37.9}/x_transformers/x_transformers.py
RENAMED
@@ -924,6 +924,7 @@ class Attention(Module):
         qk_norm_scale = 10,
         qk_norm_dim_scale = False,
         l2_distance = False,
+        sigmoid = False,
         one_kv_head = False,
         kv_heads = None,
         shared_kv = False,
@@ -1039,6 +1040,7 @@ class Attention(Module):
             qk_norm = qk_norm,
             scale = qk_norm_scale if qk_norm else self.scale,
             l2_distance = l2_distance,
+            sigmoid = sigmoid,
             add_zero_kv = add_zero_kv,
             flash = flash,
             softclamp_logits = softclamp_logits,
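Plumbed through `Attention`, the flag ultimately lands on `Attend`. A hedged sketch of exercising `Attend` directly, including the new constraint that `sigmoid` cannot be combined with `flash`; the forward signature and return value are assumed to match prior 1.37.x releases of `attend.py`:

```python
import torch
from x_transformers.attend import Attend

# assumption: Attend.forward takes (q, k, v) shaped (batch, heads, seq_len, dim_head)
# and returns (out, intermediates), as in earlier releases
attend = Attend(causal = True, sigmoid = True)   # new flag from this release

q, k, v = (torch.randn(1, 8, 256, 64) for _ in range(3))
out, _ = attend(q, k, v)
assert out.shape == (1, 8, 256, 64)

# combining sigmoid with flash attention is rejected by the new assert
try:
    Attend(sigmoid = True, flash = True)
except AssertionError as err:
    print(err)   # 'sigmoid attention not available for flash'
```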