titans-pytorch: 0.0.31 → 0.0.34 (py3-none-any.whl)

This diff compares the contents of two package versions as publicly released to a supported registry, and is provided for informational purposes only.
titans_pytorch/mac_transformer.py
@@ -17,6 +17,10 @@ from hyper_connections import get_init_and_expand_reduce_stream_functions
 from axial_positional_embedding import ContinuousAxialPositionalEmbedding
 from rotary_embedding_torch import RotaryEmbedding
 
+# proposed neural memory
+
+from titans_pytorch.titans import NeuralMemory
+
 # constants
 
 LinearNoBias = partial(Linear, bias = False)
@@ -46,13 +50,20 @@ def pad_and_segment_with_inverse(seq, segment_len):
     next_seq_len_mult = round_up_multiple(seq_len, segment_len)
 
     padding = next_seq_len_mult - seq_len
-    seq = F.pad(seq, (0, 0, 0, padding))
+    needs_pad = padding > 0
+
+    if needs_pad:
+        seq = F.pad(seq, (0, 0, 0, padding))
 
     seq = rearrange(seq, 'b (w n) d -> (b w) n d', n = segment_len)
 
     def inverse(out):
         out = rearrange(out, '(b w) n d -> b (w n) d', b = batch)
-        return out[:, :-padding]
+
+        if needs_pad:
+            out = out[:, :-padding]
+
+        return out
 
     return seq, inverse
 
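The new `needs_pad` guard fixes an edge case: when the sequence length is already a multiple of `segment_len`, `padding` is 0, and the old `out[:, :-padding]` slice evaluates as `out[:, :0]`, silently returning an empty sequence. A minimal repro of the old failure mode:

```python
import torch

out = torch.randn(2, 6, 4)   # length already a multiple of segment_len, so padding == 0
padding = 0

print(out[:, :-padding].shape)  # torch.Size([2, 0, 4]) -- the whole sequence is dropped

# the patched inverse only trims when padding was actually applied
if padding > 0:
    out = out[:, :-padding]
print(out.shape)                # torch.Size([2, 6, 4])
```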
@@ -161,7 +172,9 @@ class MemoryAsContextTransformer(Module):
         dim_head = 64,
         heads = 8,
         ff_mult = 4,
-        num_residual_streams = 4
+        num_residual_streams = 4,
+        neural_memory_kwargs: dict = dict(),
+        neural_memory_layers: tuple[int, ...] | None = None,
     ):
         super().__init__()
 
@@ -181,8 +194,25 @@ class MemoryAsContextTransformer(Module):
         init_hyper_conn, self.expand_streams, self.reduce_streams = get_init_and_expand_reduce_stream_functions(num_residual_streams, disable = num_residual_streams == 1)
 
         self.layers = ModuleList([])
+        self.neural_mem_layers = ModuleList([])
+
+        layers = tuple(range(1, depth + 1))
+        neural_memory_layers = set(default(neural_memory_layers, layers))
+
+        for layer in layers:
+
+            # neural memory
+
+            mem = None
+
+            if num_longterm_mem_tokens > 0 and layer in neural_memory_layers:
+                mem = NeuralMemory(dim = dim, chunk_size = num_longterm_mem_tokens)
+                mem = init_hyper_conn(dim = dim, branch = mem)
+
+            self.neural_mem_layers.append(mem)
+
+            # attention and feedforward
 
-        for _ in range(depth):
             attn = SegmentedAttention(
                 dim = dim,
                 dim_head = dim_head,
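Layer indices here are 1-based (`range(1, depth + 1)`), `neural_memory_layers` defaults to every layer, and a memory module is only attached when `num_longterm_mem_tokens > 0`; note that `neural_memory_kwargs` is accepted but not yet forwarded to `NeuralMemory` in this hunk. A hypothetical construction under this signature (kwargs such as `num_tokens` and `segment_len` are assumed from references elsewhere in the source, not shown in this diff):

```python
from titans_pytorch.mac_transformer import MemoryAsContextTransformer

model = MemoryAsContextTransformer(
    num_tokens = 256,
    dim = 512,
    depth = 4,
    segment_len = 128,               # assumed kwarg, per references in forward()
    num_longterm_mem_tokens = 4,     # must be > 0 for any memory to be attached
    neural_memory_layers = (2, 4),   # 1-indexed; None attaches memory at every layer
)
```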
@@ -203,7 +233,14 @@ class MemoryAsContextTransformer(Module):
 
         self.to_logits = LinearNoBias(dim, num_tokens)
 
-    def forward(self, x):
+    def forward(
+        self,
+        x,
+        return_loss = False
+    ):
+
+        if return_loss:
+            x, labels = x[:, :-1], x[:, 1:]
 
         # math
 
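With `return_loss = True`, the forward sets up standard next-token prediction: inputs drop the final position and labels drop the first, so each position predicts its successor. For example:

```python
import torch

seq = torch.tensor([[5, 7, 2, 9]])
inputs, labels = seq[:, :-1], seq[:, 1:]
# inputs: [[5, 7, 2]]  ->  labels: [[7, 2, 9]]
```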
@@ -221,7 +258,7 @@ class MemoryAsContextTransformer(Module):
         x, inverse_segment = pad_and_segment_with_inverse(x, segment_len)
 
         mems = repeat(self.longterm_mems, 'n d -> b n d', b = x.shape[0])
-        x = torch.cat((mems, x), dim = -2)
+        x = cat((mems, x), dim = -2)
 
         x = inverse_segment(x)
 
@@ -235,8 +272,27 @@ class MemoryAsContextTransformer(Module):
 
         x = self.expand_streams(x)
 
-        for attn, ff in self.layers:
+        for (attn, ff), maybe_neural_mem in zip(self.layers, self.neural_mem_layers):
+
+            if exists(maybe_neural_mem):
+                batch_streams = x.shape[0]
+
+                x, inverse_segment = pad_and_segment_with_inverse(x, total_segment_len)
+
+                longterm_mems, x = x[:, :num_longterm_mem_tokens], x[:, num_longterm_mem_tokens:]
+
+                longterm_mems = rearrange(longterm_mems, '(b w) n d -> b (w n) d', b = batch_streams)
+
+                longterm_mems = maybe_neural_mem(longterm_mems)
+
+                longterm_mems = rearrange(longterm_mems, 'b (w n) d -> (b w) n d', n = num_longterm_mem_tokens)
+
+                x = cat((longterm_mems, x), dim = -2)
+
+                x = inverse_segment(x)
+
             x = attn(x)
+
             x = ff(x)
 
         x = self.reduce_streams(x)
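The memory pass is pure shape bookkeeping: the sequence is windowed, the `num_longterm_mem_tokens` leading memory tokens of every window are gathered into one contiguous per-batch sequence, run through the memory, then scattered back in front of their windows. A standalone sketch of the rearranges (shapes are illustrative; `neural_mem` stands in for the `NeuralMemory` branch):

```python
import torch
from einops import rearrange

b, w, n, seg, d = 2, 3, 4, 16, 8      # batch*streams, windows, mem tokens, segment len, dim
num_longterm_mem_tokens = n

x = torch.randn(b * w, n + seg, d)    # windowed: (b w) (mem + segment) d

mems, rest = x[:, :num_longterm_mem_tokens], x[:, num_longterm_mem_tokens:]

mems = rearrange(mems, '(b w) n d -> b (w n) d', b = b)   # (2, 12, 8): all windows' mem tokens, in order
# mems = neural_mem(mems)                                 # the memory reads/writes over the gathered tokens
mems = rearrange(mems, 'b (w n) d -> (b w) n d', n = num_longterm_mem_tokens)  # back to (6, 4, 8)

x = torch.cat((mems, rest), dim = -2)                     # re-prepend to each window
```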
@@ -245,7 +301,7 @@ class MemoryAsContextTransformer(Module):
 
         x, inverse_segment = pad_and_segment_with_inverse(x, total_segment_len)
 
-        x = x[:, self.num_longterm_mem_tokens:]
+        x = x[:, num_longterm_mem_tokens:]
 
         x = inverse_segment(x)
 
@@ -253,4 +309,9 @@ class MemoryAsContextTransformer(Module):
 
         x = self.norm(x)
 
-        return self.to_logits(x)
+        logits = self.to_logits(x)
+
+        if not return_loss:
+            return logits
+
+        return F.cross_entropy(rearrange(logits, 'b n l -> b l n'), labels)
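`F.cross_entropy` expects logits shaped `(batch, classes, seq)`, hence the `'b n l -> b l n'` transpose before comparing against `labels`. A minimal training step, reusing the hypothetical `model` from the construction sketch above:

```python
import torch

seq = torch.randint(0, 256, (1, 1024))   # token ids in [0, num_tokens)

loss = model(seq, return_loss = True)    # internally shifts inputs/labels by one
loss.backward()
```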
titans_pytorch-0.0.34.dist-info/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.0.31
+Version: 0.0.34
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
titans_pytorch-0.0.34.dist-info/RECORD
@@ -1,9 +1,9 @@
 titans_pytorch/__init__.py,sha256=nB0873FZ_OyCda3qFeWTdpO4LbbrXDEvtAefVLzh6g0,71
 titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
-titans_pytorch/mac_transformer.py,sha256=bq5RbCgA0GWLFHTrDTIKUSQhkkuCkdjEykOwjfHDs0M,6747
+titans_pytorch/mac_transformer.py,sha256=FGShQHD-dQQdQKKzvNS_jTC_FcikdqO_s3ZKOKfr_9E,8502
 titans_pytorch/titans.py,sha256=Kx_tl_QkeDccvkMwPZ0xQ_saYjZfbKzDNPTTSHNWYcA,14304
 titans_pytorch/titans_attn_memory.py,sha256=Rwx_-riGeISBefZg5Kjic8jzmmRRys-u93D2Kgb7Mos,12691
-titans_pytorch-0.0.31.dist-info/METADATA,sha256=bN1fVL2S_vML1oqLIA92tvBhkVvnpQN11fU4e1QVI4s,3938
-titans_pytorch-0.0.31.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-titans_pytorch-0.0.31.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-titans_pytorch-0.0.31.dist-info/RECORD,,
+titans_pytorch-0.0.34.dist-info/METADATA,sha256=CNqv_jMqk7yj15IpDn2O3jBdVe4wtrSVkht7mk0wW_E,3938
+titans_pytorch-0.0.34.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+titans_pytorch-0.0.34.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+titans_pytorch-0.0.34.dist-info/RECORD,,