titans-pytorch 0.0.56__py3-none-any.whl → 0.0.58__py3-none-any.whl

This diff compares two publicly released versions of the package as published to their public registry; it is provided for informational purposes only.

titans_pytorch/mac_transformer.py

@@ -285,6 +285,49 @@ class SegmentedAttention(Module):
 
         return out, orig_v
 
+# Attention + Neural Memory gating configuration, as depicted in Figure 2
+
+class NeuralMemoryGatingWrapper(Module):
+    def __init__(
+        self,
+        dim,
+        attn: SegmentedAttention,
+        neural_mem: NeuralMemory | None = None,
+        gate_attn_output = True
+    ):
+        super().__init__()
+        self.attn = attn
+        self.neural_mem = neural_mem
+        self.gate_attn_output = gate_attn_output
+
+    def forward(
+        self,
+        seq,
+        *args,
+        **kwargs
+    ):
+        batch, seq_len = seq.shape[:2]
+        mem = self.neural_mem
+
+        if not exists(mem):
+            return self.attn(seq, *args, **kwargs), 0.
+
+        # initial retrieve, still should store first, it doesn't make sense not to, unless if all layers share the same neural memory
+
+        retrieved, kv_aux_loss = mem(seq, return_aux_kv_loss = True)
+
+        if not self.gate_attn_output:
+            seq = seq + retrieved
+
+        # attention
+
+        attn_out, values = self.attn(seq, *args, **kwargs)
+
+        if self.gate_attn_output:
+            attn_out = attn_out * retrieved.sigmoid()
+
+        return (attn_out, values), kv_aux_loss
+
 # MAC transformer
 
 class MemoryAsContextTransformer(Module):
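
The gate_attn_output flag selects between two ways of combining the neural memory's retrieval with attention. A minimal toy sketch of the two paths taken in the wrapper's forward above (the tensor names and shapes are illustrative stand-ins, not from the package):

import torch
from torch import nn

seq       = torch.randn(1, 8, 16)   # stand-in for the layer input
retrieved = torch.randn(1, 8, 16)   # stand-in for what the neural memory retrieves for seq
attn      = nn.Identity()           # stand-in for SegmentedAttention

# gate_attn_output = True (the default): retrieved memory gates the attention output
gated_out = attn(seq) * retrieved.sigmoid()

# gate_attn_output = False: retrieved memory is added to the attention input instead
residual_out = attn(seq + retrieved)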

@@ -296,6 +339,7 @@ class MemoryAsContextTransformer(Module):
         depth,
         segment_len,
         neural_memory_segment_len = None,
+        neural_mem_gate_attn_output = True,
         num_longterm_mem_tokens = 0,
         num_persist_mem_tokens = 0,
         dim_head = 64,
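
The same flag is exposed on the MemoryAsContextTransformer constructor as neural_mem_gate_attn_output. A minimal construction sketch, mirroring the package's README-style usage; num_tokens, dim, the other hyperparameter values, and the assumption that forward returns a scalar loss when return_loss = True come from that usage rather than from this diff:

import torch
from titans_pytorch.mac_transformer import MemoryAsContextTransformer

transformer = MemoryAsContextTransformer(
    num_tokens = 256,                    # vocabulary size (assumed)
    dim = 256,                           # model width (assumed)
    depth = 2,
    segment_len = 128,
    num_persist_mem_tokens = 4,
    num_longterm_mem_tokens = 16,
    neural_mem_gate_attn_output = True   # the flag added in this release (defaults to True)
)

token_ids = torch.randint(0, 256, (1, 1023))

loss = transformer(token_ids, return_loss = True)
loss.backward()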

@@ -328,7 +372,6 @@ class MemoryAsContextTransformer(Module):
 
         self.layers = ModuleList([])
 
-        self.neural_mem_layers = ModuleList([])
         self.neural_memory_segment_len = default(neural_memory_segment_len, num_longterm_mem_tokens + segment_len)
 
         layers = tuple(range(1, depth + 1))

@@ -343,7 +386,18 @@ class MemoryAsContextTransformer(Module):
         for layer in layers:
             is_first = layer == 1
 
-            # neural memory
+            # attention and feedforward
+
+            attn = SegmentedAttention(
+                dim = dim,
+                dim_head = dim_head,
+                heads = heads,
+                segment_len = segment_len,
+                use_flex_attn = use_flex_attn,
+                accept_value_residual = not is_first,
+                num_longterm_mem_tokens = num_longterm_mem_tokens,
+                num_persist_mem_tokens = num_persist_mem_tokens
+            )
 
             mem = None
 
@@ -356,21 +410,11 @@ class MemoryAsContextTransformer(Module):
                     **neural_memory_kwargs
                 )
 
-                mem = init_hyper_conn(dim = dim, branch = mem)
-
-                self.neural_mem_layers.append(mem)
-
-            # attention and feedforward
-
-            attn = SegmentedAttention(
-                dim = dim,
-                dim_head = dim_head,
-                heads = heads,
-                segment_len = segment_len,
-                use_flex_attn = use_flex_attn,
-                accept_value_residual = not is_first,
-                num_longterm_mem_tokens = num_longterm_mem_tokens,
-                num_persist_mem_tokens = num_persist_mem_tokens
+            attn = NeuralMemoryGatingWrapper(
+                dim,
+                attn = attn,
+                neural_mem = mem,
+                gate_attn_output = neural_mem_gate_attn_output
             )
 
             ff = FeedForward(dim = dim, mult = ff_mult)

@@ -460,19 +504,17 @@ class MemoryAsContextTransformer(Module):
 
         x = self.expand_streams(x)
 
-        for (attn, ff), maybe_neural_mem in zip(self.layers, self.neural_mem_layers):
-
-            if exists(maybe_neural_mem):
-                x, aux_kv_loss = maybe_neural_mem(x, return_aux_kv_loss = True)
-                kv_recon_losses = kv_recon_losses + aux_kv_loss
+        for attn, ff in self.layers:
 
-            x, values = attn(
+            (x, values), maybe_mem_kv_aux_loss = attn(
                 x,
                 value_residual = value_residual,
                 disable_flex_attn = disable_flex_attn,
                 flex_attn_fn = flex_attn_fn
             )
 
+            kv_recon_losses = kv_recon_losses + maybe_mem_kv_aux_loss
+
             value_residual = default(value_residual, values)
 
             x = ff(x)
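
With this change, each entry in self.layers yields ((attn_out, values), aux_kv_loss) from the gating wrapper rather than a bare (attn_out, values), and the KV-reconstruction losses are accumulated inside the same loop. A rough standalone sketch of that call contract, assuming SegmentedAttention and NeuralMemoryGatingWrapper are importable from titans_pytorch.mac_transformer and that SegmentedAttention can be called with only a sequence; the hyperparameter values below are illustrative:

import torch
from titans_pytorch.mac_transformer import SegmentedAttention, NeuralMemoryGatingWrapper

attn = SegmentedAttention(
    dim = 64,
    dim_head = 16,
    heads = 4,
    segment_len = 32,
    num_persist_mem_tokens = 4
)

# no NeuralMemory attached: the wrapper short-circuits to plain attention
# and reports a zero auxiliary loss, per its forward above
wrapper = NeuralMemoryGatingWrapper(64, attn = attn, neural_mem = None)

seq = torch.randn(1, 128, 64)

(out, values), aux_kv_loss = wrapper(seq)   # aux_kv_loss == 0. since no memory is attached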

titans_pytorch-{0.0.56 → 0.0.58}.dist-info/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.0.56
+Version: 0.0.58
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch

titans_pytorch-0.0.58.dist-info/RECORD (added)

@@ -0,0 +1,8 @@
+titans_pytorch/__init__.py,sha256=Q0MQA3RS8vqzs-KzSGZkvLR7afQ6ZW9uMOq1MeNuFoY,170
+titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
+titans_pytorch/mac_transformer.py,sha256=kk8s8Q2WmbJxCVi8PcqSUyJBc8-CDAHrVjt6M0d_kFs,15323
+titans_pytorch/titans.py,sha256=ZKm-LnVKh1Cxs2tSxr4CcY37KroOOmYtTFM2F3Zb8Xg,16122
+titans_pytorch-0.0.58.dist-info/METADATA,sha256=a-Y6MV_89D44HlB7eKpurh-sw5DDiS-pIVei3Uw_uGE,4457
+titans_pytorch-0.0.58.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+titans_pytorch-0.0.58.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+titans_pytorch-0.0.58.dist-info/RECORD,,

titans_pytorch-0.0.56.dist-info/RECORD (removed)

@@ -1,8 +0,0 @@
-titans_pytorch/__init__.py,sha256=Q0MQA3RS8vqzs-KzSGZkvLR7afQ6ZW9uMOq1MeNuFoY,170
-titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
-titans_pytorch/mac_transformer.py,sha256=_Vsco5YuR6uxouWcjFj-s-zPhrBcaapIzqoyi7qqY0Q,14245
-titans_pytorch/titans.py,sha256=ZKm-LnVKh1Cxs2tSxr4CcY37KroOOmYtTFM2F3Zb8Xg,16122
-titans_pytorch-0.0.56.dist-info/METADATA,sha256=QlCmHqajHiaZTps0W9gKXIHE6dShZER3PqPoYi2zRe4,4457
-titans_pytorch-0.0.56.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-titans_pytorch-0.0.56.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-titans_pytorch-0.0.56.dist-info/RECORD,,