x-transformers 1.27.10__py3-none-any.whl → 1.27.12__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- x_transformers/x_transformers.py +21 -15
- {x_transformers-1.27.10.dist-info → x_transformers-1.27.12.dist-info}/METADATA +1 -1
- {x_transformers-1.27.10.dist-info → x_transformers-1.27.12.dist-info}/RECORD +6 -6
- {x_transformers-1.27.10.dist-info → x_transformers-1.27.12.dist-info}/LICENSE +0 -0
- {x_transformers-1.27.10.dist-info → x_transformers-1.27.12.dist-info}/WHEEL +0 -0
- {x_transformers-1.27.10.dist-info → x_transformers-1.27.12.dist-info}/top_level.txt +0 -0
x_transformers/x_transformers.py
CHANGED
@@ -57,6 +57,9 @@ def maybe(fn):
         return fn(x, *args, **kwargs)
     return inner
 
+def at_most_one_of(*bools):
+    return sum(map(int, bools)) <= 1
+
 class always():
     def __init__(self, val):
         self.val = val
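The added `at_most_one_of` helper counts how many of its boolean arguments are truthy and accepts at most one; it is used further down to guard the new ViTransformerWrapper return flags. A small illustration (the example calls are ours, not from the package):

def at_most_one_of(*bools):
    # true when zero or one of the given flags is set
    return sum(map(int, bools)) <= 1

assert at_most_one_of(False, False)      # no flag set
assert at_most_one_of(True, False)       # exactly one flag set
assert not at_most_one_of(True, True)    # two flags set -> rejected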
@@ -435,19 +438,15 @@ class RotaryEmbedding(nn.Module):
 
     @autocast(enabled = False)
     def forward(self, t):
-
-
-        t = t.type_as(self.inv_freq)
-
-        t = t / self.interpolation_factor
+        max_pos = t.max()+1
 
-        freqs = torch.einsum('i , j -> i j', t, self.inv_freq)
+        freqs = torch.einsum('i , j -> i j', t.type_as(self.inv_freq), self.inv_freq) / self.interpolation_factor
         freqs = torch.cat((freqs, freqs), dim = -1)
 
         if not exists(self.scale):
             return freqs, 1.
 
-        power = (
+        power = (t - (max_pos // 2)) / self.scale_base
         scale = self.scale ** rearrange(power, 'n -> n 1')
         scale = torch.cat((scale, scale), dim = -1)
 
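The rewritten forward folds the `interpolation_factor` division into the frequency computation and derives the xpos decay power from `max_pos = t.max()+1`, so the scale for a given absolute position comes out the same whether the full position range or only the newly decoded positions are passed in. A self-contained sketch of the new computation, with made-up stand-ins (`inv_freq`, `xpos_scale`, `scale_base`, `interpolation_factor`) for the module's buffers and hyperparameters:

import torch
from einops import rearrange

# Hypothetical stand-ins for the RotaryEmbedding buffers and hyperparameters
dim = 8
inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim))
interpolation_factor = 1.
scale_base = 512
xpos_scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim)

def rotary_forward(t):
    # same shape of computation as the 1.27.12 forward: the xpos power is
    # anchored to the maximum position, not to how many positions were passed in
    max_pos = t.max() + 1

    freqs = torch.einsum('i , j -> i j', t.type_as(inv_freq), inv_freq) / interpolation_factor
    freqs = torch.cat((freqs, freqs), dim = -1)

    power = (t - (max_pos // 2)) / scale_base
    scale = xpos_scale ** rearrange(power, 'n -> n 1')
    scale = torch.cat((scale, scale), dim = -1)
    return freqs, scale

# full sequence vs. only the latest position (as when decoding with a kv cache):
# the scale for position 7 agrees because it depends on max_pos, not on len(t)
_, scale_full = rotary_forward(torch.arange(8))
_, scale_last = rotary_forward(torch.tensor([7]))
assert torch.allclose(scale_full[-1], scale_last[0])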
@@ -463,6 +462,7 @@ def rotate_half(x):
 def apply_rotary_pos_emb(t, freqs, scale = 1):
     rot_dim, seq_len = freqs.shape[-1], t.shape[-2]
     freqs = freqs[-seq_len:, :]
+    scale = scale[-seq_len:, :] if isinstance(scale, torch.Tensor) else scale
 
     if t.ndim == 4 and freqs.ndim == 3:
         freqs = rearrange(freqs, 'b n d -> b 1 n d')
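The added line mirrors the existing `freqs` cropping: when `scale` is a tensor (the xpos case) it is cut down to the trailing `seq_len` positions so it lines up with `t`, while the scalar default `scale = 1` passes through untouched. A small illustration of that slicing rule, using made-up shapes:

import torch

seq_len = 2                   # current query length (e.g. decoding with a cache)
freqs = torch.randn(10, 8)    # rotary frequencies for 10 positions, rot dim 8
scale = torch.randn(10, 8)    # xpos scale, one row per position

freqs = freqs[-seq_len:, :]   # existing behavior: keep only the trailing positions
# new in 1.27.12: crop a tensor-valued scale the same way; a scalar scale passes through
scale = scale[-seq_len:, :] if isinstance(scale, torch.Tensor) else scale

assert freqs.shape == scale.shape == (2, 8)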
@@ -1264,8 +1264,6 @@ class AttentionLayers(nn.Module):
         # rotary positions
 
         if not exists(rotary_pos_emb) and exists(self.rotary_pos_emb):
-            max_rotary_emb_length = max(list(map(lambda m: (m.shape[1] if exists(m) else 0) + x.shape[1], mems)))
-
             maybe_mem = mems[0] # todo - handle edge case where different layers get different memory lengths. don't think this will ever come up but who knows
             mem_len = maybe_mem.shape[1] if exists(maybe_mem) else 0
 
@@ -1467,7 +1465,8 @@ class ViTransformerWrapper(nn.Module):
     def forward(
         self,
         img,
-        return_embeddings = False
+        return_embeddings = False,
+        return_logits_and_embeddings = False
    ):
         b, p = img.shape[0], self.patch_size
 
@@ -1484,16 +1483,23 @@ class ViTransformerWrapper(nn.Module):
             r = repeat(self.register_tokens, 'n d -> b n d', b = b)
             x, ps = pack((x, r), 'b * d')
 
-        x = self.attn_layers(x)
+        embed = self.attn_layers(x)
 
         if self.has_register_tokens:
-            x, _ = unpack(x, ps, 'b * d')
+            embed, _ = unpack(embed, ps, 'b * d')
+
+        assert at_most_one_of(return_embeddings, return_logits_and_embeddings)
 
         if not exists(self.mlp_head) or return_embeddings:
-            return x
+            return embed
+
+        pooled = embed.mean(dim = -2)
+        logits = self.mlp_head(pooled)
+
+        if not return_logits_and_embeddings:
+            return logits
 
-        x = x.mean(dim = -2)
-        return self.mlp_head(x)
+        return logits, embed
 
 class TransformerWrapper(nn.Module):
     def __init__(
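Taken together with the signature change above, the forward pass can now return logits, embeddings, or both, and the new `at_most_one_of` assert rejects requesting `return_embeddings` and `return_logits_and_embeddings` at the same time. A usage sketch in the style of the project README (the constructor values are illustrative, not prescribed by this diff):

import torch
from x_transformers import ViTransformerWrapper, Encoder

model = ViTransformerWrapper(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    attn_layers = Encoder(
        dim = 512,
        depth = 6,
        heads = 8
    )
)

img = torch.randn(1, 3, 256, 256)

logits = model(img)                                               # (1, 1000)
embed = model(img, return_embeddings = True)                      # (1, 64, 512) patch embeddings
logits, embed = model(img, return_logits_and_embeddings = True)   # both at once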
{x_transformers-1.27.10.dist-info → x_transformers-1.27.12.dist-info}/RECORD
CHANGED
@@ -4,11 +4,11 @@ x_transformers/autoregressive_wrapper.py,sha256=gYKIN5Rm8dMYSTX5yHpg9sPYyZf9rsRT
 x_transformers/continuous.py,sha256=92Wczoaz6dJalix-e3mdIzW0xyRIx3GlBSgsSQOsJeI,6123
 x_transformers/dpo.py,sha256=ek9dgiSs05xeCn8ORceOgKy6LJOnNDw-OJDqxAVLecM,2243
 x_transformers/nonautoregressive_wrapper.py,sha256=AQLE4rA_Kh8VNoe9OzpwyeWson34sRkhks4dn4seNjI,10414
-x_transformers/x_transformers.py,sha256=
+x_transformers/x_transformers.py,sha256=_mb1-TUDTmnuF-WKwg5VlTYYJthFkFC8Q9OfrRD59sQ,63618
 x_transformers/xl_autoregressive_wrapper.py,sha256=DCx4n0_c1tFai4nOqaWVnqx2p9eutsZsDMiMP1ckxNU,4117
 x_transformers/xval.py,sha256=ulEPep6i5Hl7H-H9vGfdsmHdprUmK8ajB306jViyV2c,8147
-x_transformers-1.27.
-x_transformers-1.27.
-x_transformers-1.27.
-x_transformers-1.27.
-x_transformers-1.27.
+x_transformers-1.27.12.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.27.12.dist-info/METADATA,sha256=vpgAkW1OgQIzDRe2mZaTCHmIMO5-PUv8DO1FSXCLqy0,662
+x_transformers-1.27.12.dist-info/WHEEL,sha256=oiQVh_5PnQM0E3gPdiz09WCNmwiHDMaGer_elqB3coM,92
+x_transformers-1.27.12.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.27.12.dist-info/RECORD,,
{x_transformers-1.27.10.dist-info → x_transformers-1.27.12.dist-info}/LICENSE: File without changes
{x_transformers-1.27.10.dist-info → x_transformers-1.27.12.dist-info}/WHEEL: File without changes
{x_transformers-1.27.10.dist-info → x_transformers-1.27.12.dist-info}/top_level.txt: File without changes