titans-pytorch 0.0.26-py3-none-any.whl → 0.0.27-py3-none-any.whl

titans_pytorch/mac_transformer.py

@@ -3,10 +3,11 @@ import math
 from functools import partial
 
 import torch
-from torch import nn
+from torch import nn, cat
 import torch.nn.functional as F
 from torch.nn import Module, ModuleList, Linear
 
+from einops import repeat
 from einops.layers.torch import Rearrange
 
 from hyper_connections import get_init_and_expand_reduce_stream_functions
@@ -48,6 +49,7 @@ class SegmentedAttention(Module):
         self,
         dim,
         segment_len,
+        num_persist_mem_tokens,
         dim_head = 64,
         heads = 8,
     ):
@@ -67,6 +69,7 @@ class SegmentedAttention(Module):
         self.segment_seq = Rearrange('b (n w) d -> (b n) w d', n = segment_len)
         self.merge_seq_back = Rearrange('(b n) w d -> b (n w) d', n = segment_len)
 
+        self.persistent_memory = nn.Parameter(torch.zeros(2, heads, num_persist_mem_tokens, dim_head))
 
     def forward(self, seq):
         batch, seq_len = seq.shape[:2]
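The new persistent_memory parameter packs the learned memory keys and values into a single tensor: the leading dimension of 2 is unpacked in the forward pass into a key half and a value half, each of shape (heads, num_persist_mem_tokens, dim_head). A minimal sketch of that unpacking, with illustrative sizes (pmk/pmv mirror the names used in the forward pass below):

    import torch
    from torch import nn

    heads, num_persist_mem_tokens, dim_head = 8, 16, 64

    # one learned tensor holding both persistent keys and values;
    # unpacking the leading dimension of 2 gives one tensor for each
    persistent_memory = nn.Parameter(torch.zeros(2, heads, num_persist_mem_tokens, dim_head))

    pmk, pmv = persistent_memory   # each: (heads, num_persist_mem_tokens, dim_head)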
@@ -92,6 +95,15 @@ class SegmentedAttention(Module):
         q, k, v = self.to_qkv(seq).chunk(3, dim = -1)
         q, k, v = map(self.split_heads, (q, k, v))
 
+        # take care of persistent memory key / values
+
+        pmk, pmv = tuple(repeat(t, 'h n d -> b h n d', b = seq.shape[0]) for t in self.persistent_memory)
+
+        k = cat((pmk, k), dim = -2)
+        v = cat((pmv, v), dim = -2)
+
+        # sdpa
+
         out = F.scaled_dot_product_attention(q, k, v, is_causal = True)
 
         out = self.merge_heads(out)
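Here the newly imported repeat broadcasts the per-head memory across the batch ('h n d -> b h n d'), and cat prepends it along the key/value sequence dimension, so the memory tokens sit in front of every causal prefix. One caveat worth flagging: when the key/value sequence is longer than the query sequence, F.scaled_dot_product_attention with is_causal = True aligns its lower-triangular mask to the top-left corner, so queries do not automatically see all of the prepended memory tokens. A sketch of an explicit boolean mask that would grant every query all m memory tokens plus its causal prefix (sizes here are illustrative, not the library's own code):

    import torch
    import torch.nn.functional as F

    b, h, d = 1, 8, 64
    m, w = 16, 128   # num_persist_mem_tokens, segment_len

    q = torch.randn(b, h, w, d)
    k = torch.randn(b, h, m + w, d)   # memory keys prepended
    v = torch.randn(b, h, m + w, d)

    # row i may attend to all m memory tokens plus sequence
    # positions 0..i: lower-triangular, offset by m columns
    mask = torch.ones(w, m + w, dtype = torch.bool).tril(diagonal = m)

    out = F.scaled_dot_product_attention(q, k, v, attn_mask = mask)
    assert out.shape == (b, h, w, d)   # output length follows the queries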
@@ -113,6 +125,7 @@ class MemoryAsContextTransformer(Module):
         dim,
         depth,
         segment_len,
+        num_persist_mem_tokens,
         dim_head = 64,
         heads = 8,
         ff_mult = 4,
@@ -127,7 +140,14 @@ class MemoryAsContextTransformer(Module):
         self.layers = ModuleList([])
 
         for _ in range(depth):
-            attn = SegmentedAttention(dim = dim, dim_head = dim_head, heads = heads, segment_len = segment_len)
+            attn = SegmentedAttention(
+                dim = dim,
+                dim_head = dim_head,
+                heads = heads,
+                segment_len = segment_len,
+                num_persist_mem_tokens = num_persist_mem_tokens
+            )
+
             ff = FeedForward(dim = dim, mult = ff_mult)
 
             self.layers.append(ModuleList([
@@ -162,6 +182,7 @@ if __name__ == '__main__':
         num_tokens = 256,
         dim = 256,
         depth = 2,
+        num_persist_mem_tokens = 16,
         segment_len = 128,
     )
 
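For callers, the visible API change is the new required num_persist_mem_tokens argument, as the updated __main__ block above shows. A usage sketch, assuming the class is importable from titans_pytorch.mac_transformer and that the forward pass maps token ids to logits:

    import torch
    from titans_pytorch.mac_transformer import MemoryAsContextTransformer

    transformer = MemoryAsContextTransformer(
        num_tokens = 256,
        dim = 256,
        depth = 2,
        num_persist_mem_tokens = 16,   # new in 0.0.27
        segment_len = 128,
    )

    # sequence length kept a multiple of segment_len, since the
    # attention layers segment the sequence into fixed windows
    ids = torch.randint(0, 256, (1, 512))
    logits = transformer(ids)   # assumed: (1, 512, 256) logits over num_tokens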
titans_pytorch-0.0.26.dist-info/METADATA → titans_pytorch-0.0.27.dist-info/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.0.26
+Version: 0.0.27
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
titans_pytorch-0.0.26.dist-info/RECORD → titans_pytorch-0.0.27.dist-info/RECORD

@@ -1,9 +1,9 @@
 titans_pytorch/__init__.py,sha256=nB0873FZ_OyCda3qFeWTdpO4LbbrXDEvtAefVLzh6g0,71
 titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
-titans_pytorch/mac_transformer.py,sha256=63g4IEXbAClN6_FwcZuBPau0ndQ5yRSr0x7TIZDlxxw,4051
+titans_pytorch/mac_transformer.py,sha256=4xUSfGDdVsR-WmeXX7yXoFfybROvNCjOxL_EHDJ_Wlk,4681
 titans_pytorch/titans.py,sha256=Kx_tl_QkeDccvkMwPZ0xQ_saYjZfbKzDNPTTSHNWYcA,14304
 titans_pytorch/titans_attn_memory.py,sha256=Rwx_-riGeISBefZg5Kjic8jzmmRRys-u93D2Kgb7Mos,12691
-titans_pytorch-0.0.26.dist-info/METADATA,sha256=Dh6ymVAQ-EkWJBt0Z_eWV_UNBQfainhsus6ShywOBkk,3851
-titans_pytorch-0.0.26.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-titans_pytorch-0.0.26.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-titans_pytorch-0.0.26.dist-info/RECORD,,
+titans_pytorch-0.0.27.dist-info/METADATA,sha256=qJp1IDbphEUfW7EyNvQ7RfmHuvB7SH5h_tlnCVwV4EY,3851
+titans_pytorch-0.0.27.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+titans_pytorch-0.0.27.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+titans_pytorch-0.0.27.dist-info/RECORD,,