x-transformers 1.42.25.tar.gz → 1.42.27.tar.gz
- {x_transformers-1.42.25/x_transformers.egg-info → x_transformers-1.42.27}/PKG-INFO +1 -1
- {x_transformers-1.42.25 → x_transformers-1.42.27}/setup.py +1 -1
- {x_transformers-1.42.25 → x_transformers-1.42.27}/tests/test_x_transformers.py +6 -2
- {x_transformers-1.42.25 → x_transformers-1.42.27}/x_transformers/x_transformers.py +4 -5
- {x_transformers-1.42.25 → x_transformers-1.42.27/x_transformers.egg-info}/PKG-INFO +1 -1
- {x_transformers-1.42.25 → x_transformers-1.42.27}/LICENSE +0 -0
- {x_transformers-1.42.25 → x_transformers-1.42.27}/README.md +0 -0
- {x_transformers-1.42.25 → x_transformers-1.42.27}/setup.cfg +0 -0
- {x_transformers-1.42.25 → x_transformers-1.42.27}/x_transformers/__init__.py +0 -0
- {x_transformers-1.42.25 → x_transformers-1.42.27}/x_transformers/attend.py +0 -0
- {x_transformers-1.42.25 → x_transformers-1.42.27}/x_transformers/autoregressive_wrapper.py +0 -0
- {x_transformers-1.42.25 → x_transformers-1.42.27}/x_transformers/continuous.py +0 -0
- {x_transformers-1.42.25 → x_transformers-1.42.27}/x_transformers/dpo.py +0 -0
- {x_transformers-1.42.25 → x_transformers-1.42.27}/x_transformers/multi_input.py +0 -0
- {x_transformers-1.42.25 → x_transformers-1.42.27}/x_transformers/neo_mlp.py +0 -0
- {x_transformers-1.42.25 → x_transformers-1.42.27}/x_transformers/nonautoregressive_wrapper.py +0 -0
- {x_transformers-1.42.25 → x_transformers-1.42.27}/x_transformers/xl_autoregressive_wrapper.py +0 -0
- {x_transformers-1.42.25 → x_transformers-1.42.27}/x_transformers/xval.py +0 -0
- {x_transformers-1.42.25 → x_transformers-1.42.27}/x_transformers.egg-info/SOURCES.txt +0 -0
- {x_transformers-1.42.25 → x_transformers-1.42.27}/x_transformers.egg-info/dependency_links.txt +0 -0
- {x_transformers-1.42.25 → x_transformers-1.42.27}/x_transformers.egg-info/requires.txt +0 -0
- {x_transformers-1.42.25 → x_transformers-1.42.27}/x_transformers.egg-info/top_level.txt +0 -0
{x_transformers-1.42.25 → x_transformers-1.42.27}/tests/test_x_transformers.py

@@ -558,8 +558,10 @@ def test_laser():
 
     model(x)
 
+@pytest.mark.parametrize('self_attn_custom_pos', (True, False))
 @pytest.mark.parametrize('cross_attn_rotary', (True, False))
 def test_cross_attn_rotary(
+    self_attn_custom_pos: bool,
     cross_attn_rotary: bool
 ):
 
@@ -577,12 +579,14 @@ def test_cross_attn_rotary(
         cross_attn_dim_context = 512
     )
 
-    context_pos = torch.arange(128) if cross_attn_rotary else None
+    pos = torch.arange(64) if self_attn_custom_pos else None
+    context_pos = torch.arange(128) if cross_attn_rotary else None
 
     embed = model(
         x = x,
         mask = mask,
         context = context,
-        context_pos = context_pos,
+        pos = pos,
+        context_pos = context_pos,
         context_mask = context_mask
    )
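
Taken together, these two test hunks cover passing explicit rotary positions for both the self-attention input (pos) and the cross-attention context (context_pos). Below is a minimal usage sketch along the same lines; the Encoder constructor values other than cross_attn_dim_context = 512, and the tensor shapes, are assumptions (the hunk does not show them), chosen only so the dimensions line up.

# usage sketch mirroring the updated test; dim/depth/heads and tensor shapes are assumed
import torch
from x_transformers import Encoder

model = Encoder(
    dim = 256,
    depth = 2,
    heads = 4,
    rotary_pos_emb = True,         # rotary embeddings for self attention
    cross_attend = True,           # add cross attention blocks
    cross_attn_dim_context = 512   # context dimension, as in the test
)

x = torch.randn(1, 64, 256)
mask = torch.ones(1, 64).bool()
context = torch.randn(1, 128, 512)
context_mask = torch.ones(1, 128).bool()

pos = torch.arange(64)           # custom positions for self-attention rotary
context_pos = torch.arange(128)  # custom positions for cross-attention rotary

embed = model(
    x,
    mask = mask,
    context = context,
    pos = pos,
    context_pos = context_pos,
    context_mask = context_mask
)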
{x_transformers-1.42.25 → x_transformers-1.42.27}/x_transformers/x_transformers.py

@@ -51,8 +51,8 @@ def default(val, d):
         return val
     return d() if callable(d) else d
 
-def first(it):
-    return it[0]
+def first(it, default = None):
+    return it[0] if len(it) > 0 else default
 
 def is_empty(x):
     return len(x) == 0
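
The added default makes first safe on an empty sequence, which the AttentionLayers change further down relies on when the list of layer memories is empty. A quick illustration of the new behaviour:

# behaviour of the updated helper
def first(it, default = None):
    return it[0] if len(it) > 0 else default

assert first([5, 6, 7]) == 5          # unchanged for non-empty sequences
assert first([]) is None              # previously raised IndexError
assert first([], default = 0) == 0    # explicit fallback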
@@ -1077,7 +1077,7 @@ class Attention(Module):
         logit_softclamp_value = 50.,
         neutreno_value_residual = False, # Nguyen et al. https://arxiv.org/abs/2312.00751
         neutreno_alpha = 0.4,
-        learned_value_residual_mix = False,
+        learned_value_residual_mix = True,
         laser = False, # https://arxiv.org/abs/2411.03493v1
         laser_softclamp_value = 15.,
         onnxable = False,
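
This flips the default so that, when value residual learning is in use, the mix between each layer's values and the first layer's values is learned rather than fixed. A hedged sketch of how this is typically reached from the high-level API follows; add_value_residual and the attn_ kwarg prefix for forwarding options to Attention follow the usual x-transformers conventions and are assumptions, not part of this diff.

# sketch only: add_value_residual and the attn_ prefix are assumed x-transformers conventions
import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        add_value_residual = True                   # enable value residual learning
        # attn_learned_value_residual_mix = False   # would restore the previous fixed-mix default
    )
)

x = torch.randint(0, 20000, (1, 256))
logits = model(x)   # (1, 256, 20000)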
@@ -1357,7 +1357,6 @@ class Attention(Module):
             k = k * self.qk_norm_k_scale
 
         if exists(rotary_pos_emb):
-
             freqs, xpos_scale = rotary_pos_emb
             q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale ** -1.) if exists(xpos_scale) else (1., 1.)
 
@@ -1989,7 +1988,7 @@ class AttentionLayers(Module):
 
         if exists(self.rotary_pos_emb):
             if not exists(rotary_pos_emb):
-                maybe_mem = mems
+                maybe_mem = first(mems, None) # todo - handle edge case where different layers get different memory lengths. don't think this will ever come up but who knows
                 mem_len = maybe_mem.shape[1] if exists(maybe_mem) else 0
 
                 if not exists(pos):
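
The corrected branch runs when rotary embeddings are enabled and no explicit positions are given: it now offsets the rotary positions by the length of the first layer's memory (if any), instead of treating the whole list of per-layer memories as a single tensor. The sketch below exercises that path with transformer-XL style recurrence; the mems / return_mems calling convention comes from the standard TransformerWrapper API rather than from this diff.

# recurrence sketch; mems / return_mems are the standard TransformerWrapper API
import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 512,
    max_mem_len = 128,
    attn_layers = Decoder(
        dim = 512,
        depth = 4,
        heads = 8,
        rotary_pos_emb = True   # the branch touched by this fix
    )
)

seg1 = torch.randint(0, 256, (1, 128))
seg2 = torch.randint(0, 256, (1, 128))

logits1, mems = model(seg1, return_mems = True)   # keep per-layer memories

# feeding the memories back in, rotary positions are offset by the length of
# the first layer's memory - first(mems, None) in the patched code above
logits2, new_mems = model(seg2, mems = mems, return_mems = True)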