titans-pytorch 0.0.26__tar.gz → 0.0.27__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.0.26
+Version: 0.0.27
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -1,6 +1,6 @@
 [project]
 name = "titans-pytorch"
-version = "0.0.26"
+version = "0.0.27"
 description = "Titans"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -3,10 +3,11 @@ import math
 from functools import partial

 import torch
-from torch import nn
+from torch import nn, cat
 import torch.nn.functional as F
 from torch.nn import Module, ModuleList, Linear

+from einops import repeat
 from einops.layers.torch import Rearrange

 from hyper_connections import get_init_and_expand_reduce_stream_functions
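The new einops repeat import (together with cat from torch) is what the attention change below uses to broadcast the learned persistent-memory tokens across the batch before prepending them to the keys and values. A minimal sketch of that broadcasting pattern, with illustrative shapes that are not taken from the package:

import torch
from torch import cat
from einops import repeat

heads, num_persist_mem_tokens, dim_head, batch = 8, 16, 64, 2

# one set of data-independent tokens per head: (heads, n, dim_head)
persist = torch.zeros(heads, num_persist_mem_tokens, dim_head)

# broadcast the same tokens to every batch element: (h n d) -> (b h n d)
persist_batched = repeat(persist, 'h n d -> b h n d', b = batch)

# prepend along the sequence axis of an existing key tensor
k = torch.randn(batch, heads, 128, dim_head)
k = cat((persist_batched, k), dim = -2)

print(k.shape)  # torch.Size([2, 8, 144, 64])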
@@ -48,6 +49,7 @@ class SegmentedAttention(Module):
         self,
         dim,
         segment_len,
+        num_persist_mem_tokens,
         dim_head = 64,
         heads = 8,
     ):
@@ -67,6 +69,7 @@ class SegmentedAttention(Module):
         self.segment_seq = Rearrange('b (n w) d -> (b n) w d', n = segment_len)
         self.merge_seq_back = Rearrange('(b n) w d -> b (n w) d', n = segment_len)

+        self.persistent_memory = nn.Parameter(torch.zeros(2, heads, num_persist_mem_tokens, dim_head))

     def forward(self, seq):
         batch, seq_len = seq.shape[:2]
@@ -92,6 +95,15 @@ class SegmentedAttention(Module):
         q, k, v = self.to_qkv(seq).chunk(3, dim = -1)
         q, k, v = map(self.split_heads, (q, k, v))

+        # take care of persistent memory key / values
+
+        pmk, pmv = tuple(repeat(t, 'h n d -> b h n d', b = seq.shape[0]) for t in self.persistent_memory)
+
+        k = cat((pmk, k), dim = -2)
+        v = cat((pmv, v), dim = -2)
+
+        # sdpa
+
         out = F.scaled_dot_product_attention(q, k, v, is_causal = True)

         out = self.merge_heads(out)
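Taken together, the constructor and forward changes above give each attention layer a small bank of learned key / value tokens that are prepended to the keys and values before attention. Below is a self-contained sketch of that pattern, mirroring the hunks above; the module name and projection layout are illustrative and omit the package's actual segmentation logic in SegmentedAttention:

import torch
from torch import nn, cat
import torch.nn.functional as F
from einops import repeat, rearrange

class PersistentMemAttention(nn.Module):
    # illustrative module, not the package's SegmentedAttention
    def __init__(self, dim, num_persist_mem_tokens, dim_head = 64, heads = 8):
        super().__init__()
        dim_inner = dim_head * heads
        self.heads = heads
        self.to_qkv = nn.Linear(dim, dim_inner * 3, bias = False)
        self.to_out = nn.Linear(dim_inner, dim, bias = False)

        # one bank of keys and one of values: (2, heads, n, dim_head)
        self.persistent_memory = nn.Parameter(torch.zeros(2, heads, num_persist_mem_tokens, dim_head))

    def forward(self, seq):
        q, k, v = self.to_qkv(seq).chunk(3, dim = -1)
        q, k, v = (rearrange(t, 'b n (h d) -> b h n d', h = self.heads) for t in (q, k, v))

        # broadcast persistent memory across the batch and prepend to keys / values
        pmk, pmv = tuple(repeat(t, 'h n d -> b h n d', b = seq.shape[0]) for t in self.persistent_memory)
        k = cat((pmk, k), dim = -2)
        v = cat((pmv, v), dim = -2)

        # causal SDPA over the lengthened key / value sequence, as in the hunk above
        out = F.scaled_dot_product_attention(q, k, v, is_causal = True)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

x = torch.randn(2, 128, 256)
attn = PersistentMemAttention(dim = 256, num_persist_mem_tokens = 16)
print(attn(x).shape)  # torch.Size([2, 128, 256])

Since the persistent tokens are parameters rather than projections of the input, the same bank is shared across all positions and batch elements and is learned purely through backprop.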
@@ -113,6 +125,7 @@ class MemoryAsContextTransformer(Module):
         dim,
         depth,
         segment_len,
+        num_persist_mem_tokens,
         dim_head = 64,
         heads = 8,
         ff_mult = 4,
@@ -127,7 +140,14 @@ class MemoryAsContextTransformer(Module):
         self.layers = ModuleList([])

         for _ in range(depth):
-            attn = SegmentedAttention(dim = dim, dim_head = dim_head, heads = heads, segment_len = segment_len)
+            attn = SegmentedAttention(
+                dim = dim,
+                dim_head = dim_head,
+                heads = heads,
+                segment_len = segment_len,
+                num_persist_mem_tokens = num_persist_mem_tokens
+            )
+
             ff = FeedForward(dim = dim, mult = ff_mult)

             self.layers.append(ModuleList([
@@ -162,6 +182,7 @@ if __name__ == '__main__':
         num_tokens = 256,
         dim = 256,
         depth = 2,
+        num_persist_mem_tokens = 16,
         segment_len = 128,
     )

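Finally, the toy configuration at the bottom of the file now enables 16 persistent memory tokens per layer. A hedged usage sketch follows; the import path and the forward call returning token logits are assumptions about this version, not something shown in the diff:

import torch
from titans_pytorch.mac_transformer import MemoryAsContextTransformer  # import path is an assumption

transformer = MemoryAsContextTransformer(
    num_tokens = 256,
    dim = 256,
    depth = 2,
    num_persist_mem_tokens = 16,
    segment_len = 128,
)

# token ids; a multiple of segment_len is used since the diff does not show
# whether other sequence lengths are handled
ids = torch.randint(0, 256, (1, 512))

logits = transformer(ids)  # assumed shape: (1, 512, 256)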