x-transformers 1.32.10__tar.gz → 1.32.12__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {x_transformers-1.32.10/x_transformers.egg-info → x_transformers-1.32.12}/PKG-INFO +1 -1
- {x_transformers-1.32.10 → x_transformers-1.32.12}/setup.py +1 -1
- {x_transformers-1.32.10 → x_transformers-1.32.12}/tests/test_x_transformers.py +21 -2
- {x_transformers-1.32.10 → x_transformers-1.32.12}/x_transformers/x_transformers.py +20 -4
- {x_transformers-1.32.10 → x_transformers-1.32.12/x_transformers.egg-info}/PKG-INFO +1 -1
- {x_transformers-1.32.10 → x_transformers-1.32.12}/LICENSE +0 -0
- {x_transformers-1.32.10 → x_transformers-1.32.12}/README.md +0 -0
- {x_transformers-1.32.10 → x_transformers-1.32.12}/setup.cfg +0 -0
- {x_transformers-1.32.10 → x_transformers-1.32.12}/x_transformers/__init__.py +0 -0
- {x_transformers-1.32.10 → x_transformers-1.32.12}/x_transformers/attend.py +0 -0
- {x_transformers-1.32.10 → x_transformers-1.32.12}/x_transformers/autoregressive_wrapper.py +0 -0
- {x_transformers-1.32.10 → x_transformers-1.32.12}/x_transformers/continuous.py +0 -0
- {x_transformers-1.32.10 → x_transformers-1.32.12}/x_transformers/dpo.py +0 -0
- {x_transformers-1.32.10 → x_transformers-1.32.12}/x_transformers/multi_input.py +0 -0
- {x_transformers-1.32.10 → x_transformers-1.32.12}/x_transformers/nonautoregressive_wrapper.py +0 -0
- {x_transformers-1.32.10 → x_transformers-1.32.12}/x_transformers/xl_autoregressive_wrapper.py +0 -0
- {x_transformers-1.32.10 → x_transformers-1.32.12}/x_transformers/xval.py +0 -0
- {x_transformers-1.32.10 → x_transformers-1.32.12}/x_transformers.egg-info/SOURCES.txt +0 -0
- {x_transformers-1.32.10 → x_transformers-1.32.12}/x_transformers.egg-info/dependency_links.txt +0 -0
- {x_transformers-1.32.10 → x_transformers-1.32.12}/x_transformers.egg-info/requires.txt +0 -0
- {x_transformers-1.32.10 → x_transformers-1.32.12}/x_transformers.egg-info/top_level.txt +0 -0
{x_transformers-1.32.10 → x_transformers-1.32.12}/tests/test_x_transformers.py

@@ -159,7 +159,6 @@ def test_multiple_input_embeds():
     assert embed.shape == (2, 1024, 128)

 def test_average_pool_embed():
-
     model = TransformerWrapper(
         num_tokens = 20000,
         max_seq_len = 1024,
@@ -180,7 +179,6 @@ def test_average_pool_embed():
     assert logits.shape == (2, 20000)

 def test_cls_token():
-
     model = TransformerWrapper(
         num_tokens = 20000,
         max_seq_len = 1024,
@@ -199,3 +197,24 @@ def test_cls_token():
     logits = model(x, mask = mask)

     assert logits.shape == (2, 20000)
+
+def test_squeeze_logit_dim_one():
+    model = TransformerWrapper(
+        num_tokens = 20000,
+        max_seq_len = 1024,
+        logits_dim = 1,
+        average_pool_embed = True,
+        squeeze_out_last_dim = True,
+        attn_layers = Encoder(
+            dim = 128,
+            depth = 6,
+            heads = 8
+        )
+    )
+
+    x = torch.randint(0, 20000, (2, 1024))
+    mask = torch.randint(0, 2, (2, 1024)).bool()
+
+    logits = model(x, mask = mask)
+
+    assert logits.shape == (2,)
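The added test above is effectively the usage example for the new `squeeze_out_last_dim` flag. A minimal sketch of the same pattern outside the test suite, assuming the usual public imports from `x_transformers`, would be:

```python
# Sketch mirroring the new test above; assumes the standard public API
# (TransformerWrapper, Encoder) exported by x_transformers.
import torch
from x_transformers import TransformerWrapper, Encoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    logits_dim = 1,                # single scalar per sequence
    average_pool_embed = True,     # mean-pool tokens before the output head
    squeeze_out_last_dim = True,   # new in this release: drop the trailing size-1 dim
    attn_layers = Encoder(dim = 128, depth = 6, heads = 8)
)

x = torch.randint(0, 20000, (2, 1024))
mask = torch.randint(0, 2, (2, 1024)).bool()

preds = model(x, mask = mask)      # shape (2,) instead of (2, 1)
```

Without `squeeze_out_last_dim = True`, the same configuration returns logits of shape `(2, 1)`.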
{x_transformers-1.32.10 → x_transformers-1.32.12}/x_transformers/x_transformers.py

@@ -45,15 +45,18 @@ def default(val, d):
         return val
     return d() if callable(d) else d

-def cast_tuple(val, depth = 1):
+def first(it):
+    return it[0]
+
+def is_empty(x):
+    return len(x) == 0
+
+def cast_tuple(val, depth = 1):
     return val if isinstance(val, tuple) else (val,) * depth

 def divisible_by(num, den):
     return (num % den) == 0

-def is_empty(x):
-    return len(x) == 0
-
 def maybe(fn):
     @wraps(fn)
     def inner(x, *args, **kwargs):
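The reordered helpers are small utilities; `first` and `cast_tuple` are combined later in this diff so that a single logits tensor and a tuple of per-head logits can be handled uniformly. A standalone sketch of that pattern (the `normalize_heads` name is just for illustration, not library code):

```python
# Standalone sketch of the helper pattern: wrap the value into a tuple,
# process every element the same way, then unwrap when there is only one head.
def first(it):
    return it[0]

def cast_tuple(val, depth = 1):
    return val if isinstance(val, tuple) else (val,) * depth

def normalize_heads(logits, has_multiple_heads = False):  # hypothetical wrapper
    outs = cast_tuple(logits)          # always work with a tuple internally
    # ... per-head processing would go here ...
    if not has_multiple_heads:
        outs = first(outs)             # single head: unwrap back to one value
    return outs
```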
@@ -1922,6 +1925,7 @@ class TransformerWrapper(Module):
         attn_z_loss_weight = 1e-4,
         average_pool_embed = False,
         use_cls_token = False,
+        squeeze_out_last_dim = False
     ):
         super().__init__()

@@ -2006,6 +2010,10 @@

         self.memory_tokens_interspersed_every = memory_tokens_interspersed_every

+        # squeeze out last dimension if possible
+
+        self.squeeze_out_last_dim = squeeze_out_last_dim
+
         # whether can do cached kv decoding

         self.can_cache_kv = self.num_memory_tokens == 0
@@ -2173,6 +2181,14 @@
         else:
             logits = self.to_logits(x)

+        # maybe squeeze out last dimension of logits
+
+        if self.squeeze_out_last_dim:
+            logits = tuple((rearrange(t, '... 1 -> ...') if t.shape[-1] == 1 else t) for t in cast_tuple(logits))
+
+            if not self.has_multiple_heads:
+                logits = first(logits)
+
         # different returns

         if return_logits_and_embeddings:
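The squeeze itself uses einops, which the library already depends on; a small standalone sketch of the `rearrange(t, '... 1 -> ...')` pattern shows that it only drops a trailing dimension of size 1, which is why the diff guards it with `t.shape[-1] == 1`:

```python
# Standalone illustration of the einops pattern used in the new logits handling.
import torch
from einops import rearrange

t = torch.randn(2, 1024, 1)
squeezed = rearrange(t, '... 1 -> ...')   # removes the trailing size-1 axis
assert squeezed.shape == (2, 1024)

# For a normal vocabulary-sized head the guard skips the rearrange, since
# '... 1 -> ...' would error on a last dimension that is not exactly 1.
u = torch.randn(2, 1024, 20000)
kept = u if u.shape[-1] != 1 else rearrange(u, '... 1 -> ...')
assert kept.shape == (2, 1024, 20000)
```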