titans-pytorch 0.0.39 → 0.0.41 (py3-none-any.whl)

Diff of the published wheel contents between versions 0.0.39 and 0.0.41 of titans-pytorch, as released to PyPI.
titans_pytorch/mac_transformer.py

@@ -94,6 +94,7 @@ class SegmentedAttention(Module):
         num_longterm_mem_tokens = 0,
         dim_head = 64,
         heads = 8,
+        accept_value_residual = False,
         attend_kwargs: dict = dict()
     ):
         super().__init__()
@@ -108,6 +109,12 @@ class SegmentedAttention(Module):
         self.to_qkv = LinearNoBias(dim, dim_inner * 3)
         self.to_out = LinearNoBias(dim_inner, dim)
 
+        self.to_learned_v_mix = nn.Sequential(
+            nn.Linear(dim, heads),
+            Rearrange('b n h -> b h n 1'),
+            nn.Sigmoid()
+        ) if accept_value_residual else None
+
         self.segment_len = segment_len
         self.num_longterm_mem_tokens = num_longterm_mem_tokens
 
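The new `to_learned_v_mix` branch gives each attention head its own data-dependent gate: a token's embedding is projected to one scalar per head, rearranged so it broadcasts over the per-head value tensors, and squashed into (0, 1) by a sigmoid. A minimal shape sketch (toy sizes and standalone imports, not taken from the package itself):

```python
import torch
from torch import nn
from einops.layers.torch import Rearrange

dim, heads = 512, 8

# mirror of the new module: one sigmoid gate per head, per token
to_learned_v_mix = nn.Sequential(
    nn.Linear(dim, heads),          # (b, n, dim) -> (b, n, heads)
    Rearrange('b n h -> b h n 1'),  # heads ahead of seq; trailing 1 broadcasts over dim_head
    nn.Sigmoid()                    # gate values in (0, 1)
)

seq = torch.randn(2, 16, dim)
mix = to_learned_v_mix(seq)
print(mix.shape)  # torch.Size([2, 8, 16, 1]) - broadcastable against v of shape (b, h, n, d)
```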
@@ -118,7 +125,13 @@ class SegmentedAttention(Module):
 
         self.persistent_memory = nn.Parameter(torch.zeros(2, heads, num_persist_mem_tokens, dim_head))
 
-    def forward(self, seq):
+    def forward(
+        self,
+        seq,
+        value_residual = None
+    ):
+        assert not (exists(value_residual) ^ exists(self.to_learned_v_mix))
+
         segment_len, num_longterm_mem_tokens = self.segment_len, self.num_longterm_mem_tokens
         total_segment_len = segment_len + num_longterm_mem_tokens
 
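The assert is an XNOR guard: `value_residual` must be supplied exactly when the layer was built with `accept_value_residual = True`, i.e. when `self.to_learned_v_mix` exists. A tiny self-contained illustration, assuming the codebase's usual `exists` helper:

```python
def exists(v):
    return v is not None

def guard(value_residual, to_learned_v_mix):
    # passes only when both are present or both are absent
    assert not (exists(value_residual) ^ exists(to_learned_v_mix))

guard(None, None)          # first layer: no residual, no mixer - ok
guard('values', 'mixer')   # later layer: residual handed to a mixing layer - ok
# guard('values', None)    # would raise: residual passed to a layer that cannot mix it
```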
@@ -136,6 +149,14 @@ class SegmentedAttention(Module):
         q, k, v = self.to_qkv(seq).chunk(3, dim = -1)
         q, k, v = map(self.split_heads, (q, k, v))
 
+        # value residual
+
+        orig_v = v
+
+        if exists(self.to_learned_v_mix):
+            mix = self.to_learned_v_mix(seq)
+            v = v.lerp(value_residual, mix)
+
         # take care of persistent memory key / values
 
         pmk, pmv = tuple(repeat(t, 'h n d -> b h n d', b = seq.shape[0]) for t in self.persistent_memory)
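`Tensor.lerp` blends the freshly projected values toward the residual values: `v.lerp(r, w)` computes `v * (1 - w) + r * w` elementwise, with the per-head sigmoid gate as the weight. Note the layer stashes `orig_v` before mixing, so the unmixed values are what travel down the stack (see the changed return below). A sketch of the identity, with assumed toy shapes:

```python
import torch

b, h, n, d = 2, 8, 4, 64
v = torch.zeros(b, h, n, d)              # this layer's values
value_residual = torch.ones(b, h, n, d)  # first layer's values
mix = torch.full((b, h, n, 1), 0.25)     # sigmoid gate, broadcasts over d

blended = v.lerp(value_residual, mix)
assert torch.allclose(blended, v * (1 - mix) + value_residual * mix)
print(blended[0, 0, 0, 0].item())  # 0.25
```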
@@ -159,7 +180,7 @@ class SegmentedAttention(Module):
 
         out = inverse_segment(out)
 
-        return out
+        return out, orig_v
 
 # MAC transformer
 
@@ -171,6 +192,7 @@ class MemoryAsContextTransformer(Module):
         dim,
         depth,
         segment_len,
+        neural_memory_segment_len = None,
         num_longterm_mem_tokens = 0,
         num_persist_mem_tokens = 0,
         dim_head = 64,
@@ -200,7 +222,9 @@ class MemoryAsContextTransformer(Module):
         init_hyper_conn, self.expand_streams, self.reduce_streams = get_init_and_expand_reduce_stream_functions(num_residual_streams, disable = num_residual_streams == 1)
 
         self.layers = ModuleList([])
+
         self.neural_mem_layers = ModuleList([])
+        neural_memory_segment_len = default(neural_memory_segment_len, num_longterm_mem_tokens + segment_len)
 
         layers = tuple(range(1, depth + 1))
 
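`default` is the usual fallback helper in this codebase: keep the caller's value if given, else use the computed default. The net effect is that the neural memory's chunk size stays at its old value of `num_longterm_mem_tokens + segment_len` unless the new `neural_memory_segment_len` argument overrides it. A sketch, assuming the standard definition:

```python
def default(v, d):
    # assumed helper: return the caller's value when provided, else the fallback
    return v if v is not None else d

segment_len, num_longterm_mem_tokens = 128, 4

print(default(None, num_longterm_mem_tokens + segment_len))  # 132: previous hard-coded behavior
print(default(64, num_longterm_mem_tokens + segment_len))    # 64: the new override
```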
@@ -210,6 +234,7 @@ class MemoryAsContextTransformer(Module):
         assert not (num_longterm_mem_tokens > 0 and len(neural_memory_layers) == 0), 'empty `neural_memory_layers` when longterm memory tokens are present'
 
         for layer in layers:
+            is_first = layer == 1
 
             # neural memory
 
@@ -220,7 +245,7 @@ class MemoryAsContextTransformer(Module):
 
             mem = NeuralMemory(
                 dim = dim,
-                chunk_size = num_longterm_mem_tokens + segment_len,
+                chunk_size = neural_memory_segment_len,
                 **neural_memory_kwargs
             )
 
@@ -235,6 +260,7 @@ class MemoryAsContextTransformer(Module):
                 dim_head = dim_head,
                 heads = heads,
                 segment_len = segment_len,
+                accept_value_residual = not is_first,
                 num_longterm_mem_tokens = num_longterm_mem_tokens,
                 num_persist_mem_tokens = num_persist_mem_tokens
             )
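With `accept_value_residual = not is_first`, only the first attention layer is built without the mixer; it becomes the producer of the values that every deeper layer blends against:

```python
# how the flag falls out per layer index (1-based, as in the loop above)
depth = 4
for layer in range(1, depth + 1):
    is_first = layer == 1
    print(layer, not is_first)
# 1 False  <- produces the residual values
# 2 True   <- consumes and mixes them
# 3 True
# 4 True
```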
@@ -285,6 +311,10 @@ class MemoryAsContextTransformer(Module):
         pos_emb = self.axial_pos_emb((windows, total_segment_len), flatten = True)
         x = x + pos_emb[:x.shape[-2]]
 
+        # value residual
+
+        value_residual = None
+
         # expand and reduce streams for hyper connections
 
         x = self.expand_streams(x)
@@ -295,7 +325,9 @@ class MemoryAsContextTransformer(Module):
             x = maybe_neural_mem(x)
 
 
-            x = attn(x)
+            x, values = attn(x, value_residual = value_residual)
+
+            value_residual = default(value_residual, values)
 
             x = ff(x)
 
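End to end, this is the value-residual pattern: `default(value_residual, values)` latches the first layer's values on the first pass through the loop, and every later layer interpolates its own values toward them. A minimal self-contained sketch of the wiring (a toy stand-in module, not the package's `SegmentedAttention`):

```python
import torch
from torch import nn

def exists(v): return v is not None
def default(v, d): return v if exists(v) else d

class ToyValueMixer(nn.Module):
    # stand-in exposing the same (out, values) interface as the diff above
    def __init__(self, dim, accept_value_residual = False):
        super().__init__()
        self.to_v = nn.Linear(dim, dim)
        self.to_learned_v_mix = nn.Sequential(nn.Linear(dim, 1), nn.Sigmoid()) if accept_value_residual else None

    def forward(self, x, value_residual = None):
        assert not (exists(value_residual) ^ exists(self.to_learned_v_mix))
        v = orig_v = self.to_v(x)
        if exists(self.to_learned_v_mix):
            v = v.lerp(value_residual, self.to_learned_v_mix(x))  # blend toward layer 1's values
        return v, orig_v  # always hand back the unmixed values

dim, depth = 16, 3
layers = [ToyValueMixer(dim, accept_value_residual = i > 0) for i in range(depth)]

x = torch.randn(2, 4, dim)
value_residual = None
for attn in layers:
    x, values = attn(x, value_residual = value_residual)
    value_residual = default(value_residual, values)  # latches on the first layer's values
```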
titans_pytorch-0.0.41.dist-info/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.0.39
+Version: 0.0.41
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
titans_pytorch-0.0.41.dist-info/RECORD

@@ -1,9 +1,9 @@
 titans_pytorch/__init__.py,sha256=F6pV8BamKCsbJFVo5x2hw69vzfJNLy54SwIKIueMdp4,142
 titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
-titans_pytorch/mac_transformer.py,sha256=h58sHfufxMnSXZXyWuW-KBwzq8xwBYmFjU2XtOjUixk,8512
+titans_pytorch/mac_transformer.py,sha256=szHg8m97ew7OtipVih3pkOe1jsvhBnqvohJVJBrU5ks,9452
 titans_pytorch/titans.py,sha256=bv2Ceq-_4nNb5FNx4hLd2inC93m5MmJxO2-Mbf6PKK4,14378
 titans_pytorch/titans_attn_memory.py,sha256=Rwx_-riGeISBefZg5Kjic8jzmmRRys-u93D2Kgb7Mos,12691
-titans_pytorch-0.0.39.dist-info/METADATA,sha256=3KD2hmJ-uOyQ87Z3VB6JfaKtDcLBnoKA8037DpzJuPE,3968
-titans_pytorch-0.0.39.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-titans_pytorch-0.0.39.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-titans_pytorch-0.0.39.dist-info/RECORD,,
+titans_pytorch-0.0.41.dist-info/METADATA,sha256=RXjWshmGlzoA-OYaiI6GwieXBwbsq5ZR_hk-PM66aeg,3968
+titans_pytorch-0.0.41.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+titans_pytorch-0.0.41.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+titans_pytorch-0.0.41.dist-info/RECORD,,