x_transformers-1.27.10-py3-none-any.whl → x_transformers-1.27.12-py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry, and is provided for informational purposes only.

x_transformers/x_transformers.py

@@ -57,6 +57,9 @@ def maybe(fn):
         return fn(x, *args, **kwargs)
     return inner
 
+def at_most_one_of(*bools):
+    return sum(map(int, bools)) <= 1
+
 class always():
     def __init__(self, val):
         self.val = val
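
The new at_most_one_of helper is used further down to assert that mutually exclusive keyword flags are not passed together. A quick sketch of its behavior (the example calls are illustrative, not taken from the package):

    # helper as added in the diff above
    def at_most_one_of(*bools):
        return sum(map(int, bools)) <= 1

    assert at_most_one_of(False, False)       # no flag set
    assert at_most_one_of(True, False)        # exactly one flag set
    assert not at_most_one_of(True, True)     # two flags set at once
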
@@ -435,19 +438,15 @@ class RotaryEmbedding(nn.Module):
 
     @autocast(enabled = False)
     def forward(self, t):
-        device, seq_len = self.inv_freq.device, t.shape[-1]
-
-        t = t.type_as(self.inv_freq)
-
-        t = t / self.interpolation_factor
+        max_pos = t.max()+1
 
-        freqs = torch.einsum('i , j -> i j', t, self.inv_freq)
+        freqs = torch.einsum('i , j -> i j', t.type_as(self.inv_freq), self.inv_freq) / self.interpolation_factor
         freqs = torch.cat((freqs, freqs), dim = -1)
 
         if not exists(self.scale):
             return freqs, 1.
 
-        power = (torch.arange(seq_len, device = device) - (seq_len // 2)) / self.scale_base
+        power = (t - (max_pos // 2)) / self.scale_base
         scale = self.scale ** rearrange(power, 'n -> n 1')
         scale = torch.cat((scale, scale), dim = -1)
 
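
The rewritten forward consumes a tensor of explicit positions t rather than inferring a length from t.shape[-1]: the rotary frequencies are built directly from t, and the xpos decay power is centered on the actual positions (t - max_pos // 2) instead of arange(seq_len). A minimal standalone sketch of the new computation, with illustrative tensors and constants standing in for the module's buffers:

    import torch

    t = torch.arange(5, 10)                      # explicit, possibly offset, positions
    inv_freq = 1. / (10000 ** (torch.arange(0, 8, 2).float() / 8))
    interpolation_factor = 1.                    # stands in for self.interpolation_factor
    scale_base = 512                             # stands in for self.scale_base

    max_pos = t.max() + 1
    freqs = torch.einsum('i , j -> i j', t.type_as(inv_freq), inv_freq) / interpolation_factor
    freqs = torch.cat((freqs, freqs), dim = -1)  # duplicate angles for the two rotated halves

    # xpos decay is now a function of the positions themselves
    power = (t - (max_pos // 2)) / scale_base
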
@@ -463,6 +462,7 @@ def rotate_half(x):
 def apply_rotary_pos_emb(t, freqs, scale = 1):
     rot_dim, seq_len = freqs.shape[-1], t.shape[-2]
     freqs = freqs[-seq_len:, :]
+    scale = scale[-seq_len:, :] if isinstance(scale, torch.Tensor) else scale
 
     if t.ndim == 4 and freqs.ndim == 3:
         freqs = rearrange(freqs, 'b n d -> b 1 n d')
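
The added line keeps a tensor-valued xpos scale aligned with freqs when the current queries cover only the tail of the cached positions (for example during cached decoding). A small sketch of the shape bookkeeping, with made-up sizes:

    import torch

    seq_len = 6                                   # positions actually present in t
    freqs = torch.randn(10, 8)                    # rotary angles cached for 10 positions
    scale = torch.randn(10, 8)                    # xpos scale, one row per position

    freqs = freqs[-seq_len:, :]                   # keep only the trailing positions
    scale = scale[-seq_len:, :] if isinstance(scale, torch.Tensor) else scale
    assert freqs.shape[0] == scale.shape[0] == seq_len
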
@@ -1264,8 +1264,6 @@ class AttentionLayers(nn.Module):
         # rotary positions
 
         if not exists(rotary_pos_emb) and exists(self.rotary_pos_emb):
-            max_rotary_emb_length = max(list(map(lambda m: (m.shape[1] if exists(m) else 0) + x.shape[1], mems)))
-
             maybe_mem = mems[0] # todo - handle edge case where different layers get different memory lengths. don't think this will ever come up but who knows
             mem_len = maybe_mem.shape[1] if exists(maybe_mem) else 0
 
@@ -1467,7 +1465,8 @@ class ViTransformerWrapper(nn.Module):
     def forward(
         self,
         img,
-        return_embeddings = False
+        return_embeddings = False,
+        return_logits_and_embeddings = False
     ):
         b, p = img.shape[0], self.patch_size
 
@@ -1484,16 +1483,23 @@ class ViTransformerWrapper(nn.Module):
             r = repeat(self.register_tokens, 'n d -> b n d', b = b)
             x, ps = pack((x, r), 'b * d')
 
-        x = self.attn_layers(x)
+        embed = self.attn_layers(x)
 
         if self.has_register_tokens:
-            x, _ = unpack(x, ps, 'b * d')
+            embed, _ = unpack(embed, ps, 'b * d')
+
+        assert at_most_one_of(return_embeddings, return_logits_and_embeddings)
 
         if not exists(self.mlp_head) or return_embeddings:
-            return x
+            return embed
+
+        pooled = embed.mean(dim = -2)
+        logits = self.mlp_head(pooled)
+
+        if not return_logits_and_embeddings:
+            return logits
 
-        x = x.mean(dim = -2)
-        return self.mlp_head(x)
+        return logits, embed
 
 class TransformerWrapper(nn.Module):
     def __init__(
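
Taken together with the signature change above, ViTransformerWrapper can now return both the classification logits and the unpooled patch embeddings in one call, with the two return flags asserted to be mutually exclusive. A hedged usage sketch; the constructor arguments follow the project's documented example, and the shapes and hyperparameters are illustrative:

    import torch
    from x_transformers import ViTransformerWrapper, Encoder

    vit = ViTransformerWrapper(
        image_size = 256,
        patch_size = 32,
        num_classes = 1000,
        attn_layers = Encoder(dim = 512, depth = 6, heads = 8)
    )

    img = torch.randn(1, 3, 256, 256)

    logits = vit(img)                                              # (1, 1000) classification logits
    embed = vit(img, return_embeddings = True)                     # per-patch embeddings, mlp_head skipped
    logits, embed = vit(img, return_logits_and_embeddings = True)  # both at once (new in 1.27.12)
    # passing both flags as True would trip the new at_most_one_of assertion
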
x_transformers-1.27.12.dist-info/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: x-transformers
-Version: 1.27.10
+Version: 1.27.12
 Summary: X-Transformers - Pytorch
 Home-page: https://github.com/lucidrains/x-transformers
 Author: Phil Wang

x_transformers-1.27.12.dist-info/RECORD

@@ -4,11 +4,11 @@ x_transformers/autoregressive_wrapper.py,sha256=gYKIN5Rm8dMYSTX5yHpg9sPYyZf9rsRT
 x_transformers/continuous.py,sha256=92Wczoaz6dJalix-e3mdIzW0xyRIx3GlBSgsSQOsJeI,6123
 x_transformers/dpo.py,sha256=ek9dgiSs05xeCn8ORceOgKy6LJOnNDw-OJDqxAVLecM,2243
 x_transformers/nonautoregressive_wrapper.py,sha256=AQLE4rA_Kh8VNoe9OzpwyeWson34sRkhks4dn4seNjI,10414
-x_transformers/x_transformers.py,sha256=TPH5PitIzIBWTQdnO8nlctB8poSMvHkBPWcWFolgZAM,63429
+x_transformers/x_transformers.py,sha256=_mb1-TUDTmnuF-WKwg5VlTYYJthFkFC8Q9OfrRD59sQ,63618
 x_transformers/xl_autoregressive_wrapper.py,sha256=DCx4n0_c1tFai4nOqaWVnqx2p9eutsZsDMiMP1ckxNU,4117
 x_transformers/xval.py,sha256=ulEPep6i5Hl7H-H9vGfdsmHdprUmK8ajB306jViyV2c,8147
-x_transformers-1.27.10.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-1.27.10.dist-info/METADATA,sha256=v2ZVeG1yd-HPYFbBWYjNL-q4s74asgt7U8VWg4f9Leg,662
-x_transformers-1.27.10.dist-info/WHEEL,sha256=oiQVh_5PnQM0E3gPdiz09WCNmwiHDMaGer_elqB3coM,92
-x_transformers-1.27.10.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
-x_transformers-1.27.10.dist-info/RECORD,,
+x_transformers-1.27.12.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.27.12.dist-info/METADATA,sha256=vpgAkW1OgQIzDRe2mZaTCHmIMO5-PUv8DO1FSXCLqy0,662
+x_transformers-1.27.12.dist-info/WHEEL,sha256=oiQVh_5PnQM0E3gPdiz09WCNmwiHDMaGer_elqB3coM,92
+x_transformers-1.27.12.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.27.12.dist-info/RECORD,,