x-transformers 1.23.4__py3-none-any.whl → 1.23.6__py3-none-any.whl

This diff shows the changes between two publicly released versions of this package, as they appear in their public registry. It is provided for informational purposes only.
x_transformers/__init__.py
@@ -1,13 +1,19 @@
- import torch
- from packaging import version
-
- if version.parse(torch.__version__) >= version.parse('2.0.0'):
-     from einops._torch_specific import allow_ops_in_compiled_graph
-     allow_ops_in_compiled_graph()
-
- from x_transformers.x_transformers import XTransformer, Encoder, Decoder, CrossAttender, Attention, TransformerWrapper, ViTransformerWrapper, ContinuousTransformerWrapper
+ from x_transformers.x_transformers import (
+     XTransformer,
+     Encoder,
+     Decoder,
+     CrossAttender,
+     Attention,
+     TransformerWrapper,
+     ViTransformerWrapper
+ )

  from x_transformers.autoregressive_wrapper import AutoregressiveWrapper
  from x_transformers.nonautoregressive_wrapper import NonAutoregressiveWrapper
- from x_transformers.continuous_autoregressive_wrapper import ContinuousAutoregressiveWrapper
+
+ from x_transformers.continuous_autoregressive_wrapper import (
+     ContinuousTransformerWrapper,
+     ContinuousAutoregressiveWrapper
+ )
+
  from x_transformers.xl_autoregressive_wrapper import XLAutoregressiveWrapper
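
The package-level API is unaffected by the move: ContinuousTransformerWrapper is still re-exported from the package root, it just now lives in continuous_autoregressive_wrapper instead of x_transformers.x_transformers (see the hunks below). Code that imported it from x_transformers.x_transformers directly will need to import from the root or from the new module. A quick sanity check, assuming x-transformers >= 1.23.6 is installed:

    # Both names still resolve from the package root after the move.
    from x_transformers import ContinuousTransformerWrapper, ContinuousAutoregressiveWrapper

    # The class now reports its new home module.
    print(ContinuousTransformerWrapper.__module__)
    # -> x_transformers.continuous_autoregressive_wrapper

Note also that the torch/packaging version check enabling einops' allow_ops_in_compiled_graph() at import time is dropped from __init__.py in 1.23.6.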
x_transformers/continuous_autoregressive_wrapper.py
@@ -2,11 +2,147 @@ import torch
  from torch import nn
  import torch.nn.functional as F

+ from x_transformers.x_transformers import (
+     AttentionLayers,
+     ScaledSinusoidalEmbedding,
+     AbsolutePositionalEmbedding
+ )
+
+ # helper functions
+
  def exists(val):
      return val is not None

+ def default(val, d):
+     if exists(val):
+         return val
+     return d() if callable(d) else d
+
+ # main classes
+
+ class ContinuousTransformerWrapper(nn.Module):
+     def __init__(
+         self,
+         *,
+         max_seq_len,
+         attn_layers: AttentionLayers,
+         dim_in = None,
+         dim_out = None,
+         emb_dim = None,
+         max_mem_len = 0,
+         num_memory_tokens = None,
+         post_emb_norm = False,
+         emb_dropout = 0.,
+         use_abs_pos_emb = True,
+         scaled_sinu_pos_emb = False
+     ):
+         super().__init__()
+         dim = attn_layers.dim
+
+         self.max_seq_len = max_seq_len
+
+         self.max_mem_len = max_mem_len
+
+         if not (use_abs_pos_emb and not attn_layers.has_pos_emb):
+             self.pos_emb = always(0)
+         elif scaled_sinu_pos_emb:
+             self.pos_emb = ScaledSinusoidalEmbedding(dim)
+         else:
+             self.pos_emb = AbsolutePositionalEmbedding(dim, max_seq_len)
+
+         self.post_emb_norm = nn.LayerNorm(dim) if post_emb_norm else nn.Identity()
+         self.emb_dropout = nn.Dropout(emb_dropout)
+
+         # memory tokens
+
+         num_memory_tokens = default(num_memory_tokens, 0)
+         self.has_memory_tokens = num_memory_tokens > 0
+
+         if num_memory_tokens > 0:
+             self.memory_tokens = nn.Parameter(torch.randn(num_memory_tokens, dim))
+
+         # attention layers
+
+         self.attn_layers = attn_layers
+
+         # project in and out
+
+         self.project_in = nn.Linear(dim_in, dim) if exists(dim_in) else nn.Identity()
+         self.project_out = nn.Linear(dim, dim_out) if exists(dim_out) else nn.Identity()
+
+     def forward(
+         self,
+         x,
+         return_embeddings = False,
+         return_intermediates = False,
+         return_mems = False,
+         mask = None,
+         return_attn = False,
+         mems = None,
+         pos = None,
+         prepend_embeds = None,
+         **kwargs
+     ):
+         batch = x.shape[0]
+
+         x = self.project_in(x)
+         x = x + self.pos_emb(x, pos = pos)
+
+         x = self.post_emb_norm(x)
+
+         # memory tokens
+
+         if self.has_memory_tokens:
+             m = repeat(self.memory_tokens, 'm d -> b m d', b = batch)
+             x, mem_ps = pack([m, x], 'b * d')
+
+             if exists(mask):
+                 num_mems = m.shape[-2]
+                 mask = pad_at_dim(mask, (num_mems, 0), dim = -1, value = True)
+
+         # whether to append embeds, as in PaLI, for image embeddings
+
+         if exists(prepend_embeds):
+             _, prepend_dim = prepend_embeds.shape[1:]
+             assert prepend_dim == x.shape[-1], 'prepended embeddings need to have same dimensions as model dimensions'
+
+             x = torch.cat((prepend_embeds, x), dim = -2)
+
+         x = self.emb_dropout(x)
+
+         # attention layers
+
+         x, intermediates = self.attn_layers(x, mask = mask, mems = mems, return_hiddens = True, **kwargs)
+
+         # splice out memory tokens
+
+         if self.has_memory_tokens:
+             m, x = unpack(x, mem_ps, 'b * d')
+             intermediates.memory_tokens = m
+
+         out = self.project_out(x) if not return_embeddings else x
+
+         if return_intermediates:
+             return out, intermediates
+
+         if return_mems:
+             hiddens = intermediates.hiddens
+             new_mems = list(map(lambda t: t[..., -self.max_mem_len:, :].detach(), hiddens))
+             return out, new_mems
+
+         if return_attn:
+             attn_maps = list(map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates))
+             return out, attn_maps
+
+         return out
+
  class ContinuousAutoregressiveWrapper(nn.Module):
-     def __init__(self, net, ignore_index = -100, pad_value = 0):
+     def __init__(
+         self,
+         net: ContinuousTransformerWrapper,
+         ignore_index = -100,
+         pad_value = 0
+     ):
          super().__init__()
          self.net = net
          self.max_seq_len = net.max_seq_len
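
This is the same class that is deleted from x_transformers.py further down, with three visible changes: attn_layers gains an AttentionLayers annotation in place of an isinstance assert, ContinuousAutoregressiveWrapper annotates net as ContinuousTransformerWrapper, and the final memory tokens are now stored on intermediates.memory_tokens. One caveat visible in the hunk: the memory-token code path references helpers (always, repeat, pack, unpack, pad_at_dim) whose imports do not appear here, so the sketch below leaves num_memory_tokens unset. A minimal usage sketch, with arbitrary illustrative dimensions:

    import torch
    from x_transformers import ContinuousTransformerWrapper, Encoder

    # Continuous (non-token) inputs: project 32-dim features into a 512-dim encoder and back out.
    model = ContinuousTransformerWrapper(
        dim_in = 32,
        dim_out = 32,
        max_seq_len = 1024,
        attn_layers = Encoder(dim = 512, depth = 6, heads = 8)
    )

    x = torch.randn(1, 256, 32)          # (batch, seq, dim_in)
    mask = torch.ones(1, 256).bool()     # optional padding mask

    out = model(x, mask = mask)          # (1, 256, 32), projected back to dim_out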
x_transformers/x_transformers.py
@@ -6,7 +6,6 @@ from torch import nn, einsum, Tensor
  import torch.nn.functional as F

  from functools import partial, wraps
- from inspect import isfunction
  from collections import namedtuple
  from dataclasses import dataclass
  from typing import List, Callable, Optional
@@ -28,6 +27,7 @@ class LayerIntermediates:
      layer_hiddens: Optional[List[Tensor]] = None
      attn_z_loss: Optional[Tensor] = None
      mems: Optional[Tensor] = None
+     memory_tokens: Optional[Tensor] = None

  # helpers

@@ -37,7 +37,7 @@ def exists(val):
  def default(val, d):
      if exists(val):
          return val
-     return d() if isfunction(d) else d
+     return d() if callable(d) else d

  def cast_tuple(val, depth):
      return val if isinstance(val, tuple) else (val,) * depth
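
Switching from inspect.isfunction to the built-in callable broadens which defaults are invoked lazily: isfunction only recognizes plain Python functions, so other callables such as classes or functools.partial objects were previously returned as-is rather than called. A small standalone illustration of the difference:

    from inspect import isfunction
    from functools import partial

    make_empty = partial(dict)      # callable, but not a plain function

    print(isfunction(make_empty))   # False -> the old default() would return the partial object itself
    print(callable(make_empty))     # True  -> the new default() calls it and returns {}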
@@ -1339,7 +1339,7 @@ class ViTransformerWrapper(nn.Module):
          *,
          image_size,
          patch_size,
-         attn_layers,
+         attn_layers: Encoder,
          channels = 3,
          num_classes = None,
          post_emb_norm = False,
@@ -1347,7 +1347,6 @@ class ViTransformerWrapper(nn.Module):
          emb_dropout = 0.
      ):
          super().__init__()
-         assert isinstance(attn_layers, Encoder), 'attention layers must be an Encoder'
          assert divisible_by(image_size, patch_size), 'image dimensions must be divisible by the patch size'
          dim = attn_layers.dim
          num_patches = (image_size // patch_size) ** 2
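
In these hunks (and the TransformerWrapper ones that follow), the runtime isinstance asserts are replaced by parameter annotations such as attn_layers: Encoder and attn_layers: AttentionLayers. Passing the wrong layer type is therefore no longer caught by a descriptive assertion at construction time; the expectation is documented in the signature and left to static type checkers. Construction itself is unchanged; a sketch following the project README:

    from x_transformers import ViTransformerWrapper, Encoder

    # attn_layers is annotated as Encoder in 1.23.6; the old
    # `assert isinstance(attn_layers, Encoder)` guard is gone.
    vit = ViTransformerWrapper(
        image_size = 256,
        patch_size = 32,
        num_classes = 1000,
        attn_layers = Encoder(dim = 512, depth = 6, heads = 8)
    )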
@@ -1413,7 +1412,7 @@ class TransformerWrapper(nn.Module):
          *,
          num_tokens,
          max_seq_len,
-         attn_layers,
+         attn_layers: AttentionLayers,
          emb_dim = None,
          max_mem_len = 0,
          shift_mem_down = 0,
@@ -1430,7 +1429,6 @@ class TransformerWrapper(nn.Module):
          attn_z_loss_weight = 1e-4,
      ):
          super().__init__()
-         assert isinstance(attn_layers, AttentionLayers), 'attention layers must be one of Encoder or Decoder'

          dim = attn_layers.dim
          emb_dim = default(emb_dim, dim)
@@ -1576,6 +1574,8 @@ class TransformerWrapper(nn.Module):

              mem, x = unpack(x, mem_packed_shape, 'b * d')

+             intermediates.memory_tokens = mem
+
              if exists(mem_every):
                  x = rearrange(x, '(b n) m d -> b (n m) d', b = b)

@@ -1612,123 +1612,6 @@ class TransformerWrapper(nn.Module):

          return out

- class ContinuousTransformerWrapper(nn.Module):
-     def __init__(
-         self,
-         *,
-         max_seq_len,
-         attn_layers,
-         dim_in = None,
-         dim_out = None,
-         emb_dim = None,
-         max_mem_len = 0,
-         num_memory_tokens = None,
-         post_emb_norm = False,
-         emb_dropout = 0.,
-         use_abs_pos_emb = True,
-         scaled_sinu_pos_emb = False
-     ):
-         super().__init__()
-         assert isinstance(attn_layers, AttentionLayers), 'attention layers must be one of Encoder or Decoder'
-
-         dim = attn_layers.dim
-
-         self.max_seq_len = max_seq_len
-
-         self.max_mem_len = max_mem_len
-
-         if not (use_abs_pos_emb and not attn_layers.has_pos_emb):
-             self.pos_emb = always(0)
-         elif scaled_sinu_pos_emb:
-             self.pos_emb = ScaledSinusoidalEmbedding(dim)
-         else:
-             self.pos_emb = AbsolutePositionalEmbedding(dim, max_seq_len)
-
-         self.post_emb_norm = nn.LayerNorm(dim) if post_emb_norm else nn.Identity()
-         self.emb_dropout = nn.Dropout(emb_dropout)
-
-         # memory tokens
-
-         num_memory_tokens = default(num_memory_tokens, 0)
-         self.has_memory_tokens = num_memory_tokens > 0
-
-         if num_memory_tokens > 0:
-             self.memory_tokens = nn.Parameter(torch.randn(num_memory_tokens, dim))
-
-         # attention layers
-
-         self.attn_layers = attn_layers
-
-         # project in and out
-
-         self.project_in = nn.Linear(dim_in, dim) if exists(dim_in) else nn.Identity()
-         self.project_out = nn.Linear(dim, dim_out) if exists(dim_out) else nn.Identity()
-
-     def forward(
-         self,
-         x,
-         return_embeddings = False,
-         return_intermediates = False,
-         return_mems = False,
-         mask = None,
-         return_attn = False,
-         mems = None,
-         pos = None,
-         prepend_embeds = None,
-         **kwargs
-     ):
-         batch = x.shape[0]
-
-         x = self.project_in(x)
-         x = x + self.pos_emb(x, pos = pos)
-
-         x = self.post_emb_norm(x)
-
-         # memory tokens
-
-         if self.has_memory_tokens:
-             m = repeat(self.memory_tokens, 'm d -> b m d', b = batch)
-             x, mem_ps = pack([m, x], 'b * d')
-
-             if exists(mask):
-                 num_mems = m.shape[-2]
-                 mask = pad_at_dim(mask, (num_mems, 0), dim = -1, value = True)
-
-         # whether to append embeds, as in PaLI, for image embeddings
-
-         if exists(prepend_embeds):
-             _, prepend_dim = prepend_embeds.shape[1:]
-             assert prepend_dim == x.shape[-1], 'prepended embeddings need to have same dimensions as model dimensions'
-
-             x = torch.cat((prepend_embeds, x), dim = -2)
-
-         x = self.emb_dropout(x)
-
-         # attention layers
-
-         x, intermediates = self.attn_layers(x, mask = mask, mems = mems, return_hiddens = True, **kwargs)
-
-         # splice out memory tokens
-
-         if self.has_memory_tokens:
-             m, x = unpack(x, mem_ps, 'b * d')
-
-         out = self.project_out(x) if not return_embeddings else x
-
-         if return_intermediates:
-             return out, intermediates
-
-         if return_mems:
-             hiddens = intermediates.hiddens
-             new_mems = list(map(lambda t: t[..., -self.max_mem_len:, :].detach(), hiddens))
-             return out, new_mems
-
-         if return_attn:
-             attn_maps = list(map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates))
-             return out, attn_maps
-
-         return out
-
  class XTransformer(nn.Module):
      def __init__(
          self,
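
With memory_tokens now recorded on LayerIntermediates (see the hunks above), the final memory-token embeddings from TransformerWrapper, and from the relocated ContinuousTransformerWrapper, can be read back after a forward pass instead of being discarded when they are spliced out. A sketch of how that would look, with illustrative sizes and num_memory_tokens > 0 so the memory-token branch actually runs:

    import torch
    from x_transformers import TransformerWrapper, Decoder

    model = TransformerWrapper(
        num_tokens = 20000,
        max_seq_len = 1024,
        num_memory_tokens = 4,    # required for intermediates.memory_tokens to be populated
        attn_layers = Decoder(dim = 512, depth = 6, heads = 8)
    )

    tokens = torch.randint(0, 20000, (1, 128))
    logits, intermediates = model(tokens, return_intermediates = True)

    print(intermediates.memory_tokens.shape)   # expected: (1, 4, 512) = (batch, num_memory_tokens, dim)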
x_transformers-1.23.4.dist-info/METADATA → x_transformers-1.23.6.dist-info/METADATA
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: x-transformers
- Version: 1.23.4
+ Version: 1.23.6
  Summary: X-Transformers - Pytorch
  Home-page: https://github.com/lucidrains/x-transformers
  Author: Phil Wang
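
After upgrading, the installed version can be checked against the metadata above with a standard-library one-liner:

    import importlib.metadata
    print(importlib.metadata.version('x-transformers'))   # '1.23.6' once the new wheel is installed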
x_transformers-1.23.6.dist-info/RECORD
@@ -0,0 +1,12 @@
+ x_transformers/__init__.py,sha256=LLbGkiUKu4nR4YOISh-1gPEZYs8R3I2KMwMZHMQ2YkU,538
+ x_transformers/attend.py,sha256=hZcz_iijzbEqbXp2_BPEVL-1LoHXmYaHE6e6Oy-7hFE,11263
+ x_transformers/autoregressive_wrapper.py,sha256=f2u0usjUfAlXwgTz87O8J8XjGTbsbrx2XEP6K2beSNI,8944
+ x_transformers/continuous_autoregressive_wrapper.py,sha256=lGqE5vFaDeuLFc7b-dAQ0hx3H4dHFK_yD4-tZQZ7vqQ,5337
+ x_transformers/nonautoregressive_wrapper.py,sha256=AQLE4rA_Kh8VNoe9OzpwyeWson34sRkhks4dn4seNjI,10414
+ x_transformers/x_transformers.py,sha256=8zfU1iqrd6AwT-L23jkK7tpunxfjLi4HKgbZ6evBWKU,58416
+ x_transformers/xl_autoregressive_wrapper.py,sha256=DCx4n0_c1tFai4nOqaWVnqx2p9eutsZsDMiMP1ckxNU,4117
+ x_transformers-1.23.6.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+ x_transformers-1.23.6.dist-info/METADATA,sha256=UnmCziZTo4xYZJjAdgtRyWLytboEAR0JmyK2i0hwMvc,661
+ x_transformers-1.23.6.dist-info/WHEEL,sha256=yQN5g4mg4AybRjkgi-9yy4iQEFibGQmlz78Pik5Or-A,92
+ x_transformers-1.23.6.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+ x_transformers-1.23.6.dist-info/RECORD,,
x_transformers-1.23.4.dist-info/RECORD
@@ -1,12 +0,0 @@
- x_transformers/__init__.py,sha256=FDb654rUx8FpXRd76B8q0diH8I7q-ZjTWEtEJ4UM21Y,701
- x_transformers/attend.py,sha256=hZcz_iijzbEqbXp2_BPEVL-1LoHXmYaHE6e6Oy-7hFE,11263
- x_transformers/autoregressive_wrapper.py,sha256=f2u0usjUfAlXwgTz87O8J8XjGTbsbrx2XEP6K2beSNI,8944
- x_transformers/continuous_autoregressive_wrapper.py,sha256=pTiDqu6JRUlnQJQp_xHATYHy0lgSd6ERLqyiFO3pC-4,1575
- x_transformers/nonautoregressive_wrapper.py,sha256=AQLE4rA_Kh8VNoe9OzpwyeWson34sRkhks4dn4seNjI,10414
- x_transformers/x_transformers.py,sha256=jjuh7MLIlV4pHsJHONFlQy96QTdI5VonR9cc8FVs4J8,62009
- x_transformers/xl_autoregressive_wrapper.py,sha256=DCx4n0_c1tFai4nOqaWVnqx2p9eutsZsDMiMP1ckxNU,4117
- x_transformers-1.23.4.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
- x_transformers-1.23.4.dist-info/METADATA,sha256=51dyPgX1be0-AH9PktXW_fme4GejI0YgZXbviy97j0Q,661
- x_transformers-1.23.4.dist-info/WHEEL,sha256=yQN5g4mg4AybRjkgi-9yy4iQEFibGQmlz78Pik5Or-A,92
- x_transformers-1.23.4.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
- x_transformers-1.23.4.dist-info/RECORD,,