titans-pytorch 0.1.17__tar.gz → 0.1.18__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.1.17
+Version: 0.1.18
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -38,7 +38,7 @@ Requires-Dist: accelerated-scan>=0.2.0
 Requires-Dist: axial-positional-embedding>=0.3.9
 Requires-Dist: einops>=0.8.0
 Requires-Dist: einx>=0.3.0
-Requires-Dist: hyper-connections>=0.1.8
+Requires-Dist: hyper-connections>=0.1.9
 Requires-Dist: ninja
 Requires-Dist: rotary-embedding-torch
 Requires-Dist: tensordict
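The only dependency change in this release is the bump of hyper-connections from 0.1.8 to 0.1.9 (mirrored in pyproject.toml below), presumably to pick up the add_branch_out_to_residual argument that init_hyper_conn is called with later in this diff. Purely as an illustration, here is a plain-PyTorch stand-in (not the hyper-connections API; class and argument names are placeholders) for a residual branch whose output can be kept out of the residual stream:

import torch
from torch import nn

class OptionalResidualBranch(nn.Module):
    # placeholder sketch: run a branch module and optionally skip adding its
    # output back onto the residual stream, so the output can instead be
    # consumed elsewhere (e.g. as gates)
    def __init__(self, branch: nn.Module, add_branch_out_to_residual = True):
        super().__init__()
        self.branch = branch
        self.add_branch_out_to_residual = add_branch_out_to_residual

    def forward(self, x):
        branch_out = self.branch(x)
        residual = x + branch_out if self.add_branch_out_to_residual else x
        return branch_out, residual

# the branch output is returned either way; only the residual stream differs
gating_branch = OptionalResidualBranch(nn.Linear(512, 512), add_branch_out_to_residual = False)
branch_out, residual = gating_branch(torch.randn(2, 128, 512))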
@@ -1,6 +1,6 @@
 [project]
 name = "titans-pytorch"
-version = "0.1.17"
+version = "0.1.18"
 description = "Titans"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -29,7 +29,7 @@ dependencies = [
     "axial_positional_embedding>=0.3.9",
     "einops>=0.8.0",
     "einx>=0.3.0",
-    "hyper-connections>=0.1.8",
+    "hyper-connections>=0.1.9",
     "Ninja",
     "rotary-embedding-torch",
     "tensordict",
@@ -217,7 +217,8 @@ class SegmentedAttention(Module):
         self,
         seq,
         value_residual = None,
-        flex_attn_fn: Callable | None = None
+        flex_attn_fn: Callable | None = None,
+        output_gating = None
     ):

         assert not (exists(value_residual) ^ exists(self.to_learned_v_mix))
@@ -267,6 +268,9 @@ class SegmentedAttention(Module):

         out = self.to_out(out)

+        if exists(output_gating):
+            out = out * output_gating
+
         return out, orig_v

     def forward(
@@ -274,10 +278,11 @@ class SegmentedAttention(Module):
         seq,
         value_residual = None,
         flex_attn_fn: Callable | None = None,
-        disable_flex_attn = False
+        disable_flex_attn = False,
+        output_gating = None
     ):
         if seq.is_cuda and self.use_flex_attn and not disable_flex_attn:
-            return self.forward_flex(seq, value_residual, flex_attn_fn)
+            return self.forward_flex(seq, value_residual, flex_attn_fn, output_gating = output_gating)

         assert not (exists(value_residual) ^ exists(self.to_learned_v_mix))

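The three hunks above thread a new optional output_gating argument through SegmentedAttention: both the flex-attention path and the regular path now multiply the attention output elementwise by the provided gates. A minimal standalone sketch of that operation, using hypothetical names rather than the library's actual modules:

import torch

def apply_output_gating(attn_out, output_gating = None):
    # attn_out: (batch, seq_len, dim) attention output
    # output_gating: broadcastable gates, typically in [0, 1], or None to disable
    if output_gating is None:
        return attn_out
    return attn_out * output_gating

attn_out = torch.randn(2, 128, 512)
retrieved = torch.randn(2, 128, 512)              # stand-in for a neural memory read
gated = apply_output_gating(attn_out, retrieved.sigmoid())
assert gated.shape == attn_out.shape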
@@ -361,50 +366,10 @@ class SegmentedAttention(Module):

         out = inverse_segment(out)

-        return out, orig_v
-
-# Attention + Neural Memory gating configuration, as depicted in Figure 2
-
-class NeuralMemoryGatingWrapper(Module):
-    def __init__(
-        self,
-        dim,
-        attn: SegmentedAttention,
-        neural_mem: NeuralMemory | None = None,
-        gate_attn_output = True
-    ):
-        super().__init__()
-        self.attn = attn
-        self.neural_mem = neural_mem
-        self.gate_attn_output = gate_attn_output
-
-    def forward(
-        self,
-        seq,
-        *args,
-        **kwargs
-    ):
-        batch, seq_len = seq.shape[:2]
-        mem = self.neural_mem
-
-        if not exists(mem):
-            return self.attn(seq, *args, **kwargs), 0.
+        if exists(output_gating):
+            out = out * output_gating

-        # initial retrieve, still should store first, it doesn't make sense not to, unless if all layers share the same neural memory
-
-        retrieved, kv_aux_loss = mem(seq, return_aux_kv_loss = True)
-
-        if not self.gate_attn_output:
-            seq = seq + retrieved
-
-        # attention
-
-        attn_out, values = self.attn(seq, *args, **kwargs)
-
-        if self.gate_attn_output:
-            attn_out = attn_out * retrieved.sigmoid()
-
-        return (attn_out, values), kv_aux_loss
+        return out, orig_v

 # MAC transformer

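The deletion above removes NeuralMemoryGatingWrapper, which paired neural memory retrieval with attention in one of two modes: add the retrieved memories to the sequence before attention, or pass them through a sigmoid and gate the attention output, while also surfacing the memory's auxiliary key/value reconstruction loss. The remaining hunks relocate that logic into MemoryAsContextTransformer itself. A condensed sketch of the two modes, using hypothetical stand-in callables rather than the real modules:

import torch

def memory_then_attention(x, memory_read, attention, gate_attn_output = True):
    # memory_read and attention are stand-ins for the neural memory and attention modules
    retrieved = memory_read(x)                       # read from neural memory

    if not gate_attn_output:
        x = x + retrieved                            # residual mode: add memories to the stream

    attn_out = attention(x)

    if gate_attn_output:
        attn_out = attn_out * retrieved.sigmoid()    # gating mode: memories gate the attention output

    return attn_out

x = torch.randn(2, 64, 512)
out = memory_then_attention(x, memory_read = torch.tanh, attention = lambda t: t)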
@@ -494,16 +459,10 @@ class MemoryAsContextTransformer(Module):
                     **neural_memory_kwargs
                 )

-            attn = NeuralMemoryGatingWrapper(
-                dim,
-                attn = attn,
-                neural_mem = mem,
-                gate_attn_output = neural_mem_gate_attn_output
-            )
-
             ff = FeedForward(dim = dim, mult = ff_mult)

             self.layers.append(ModuleList([
+                init_hyper_conn(dim = dim, branch = mem, add_branch_out_to_residual = not neural_mem_gate_attn_output) if exists(mem) else None,
                 init_hyper_conn(dim = dim, branch = attn),
                 init_hyper_conn(dim = dim, branch = ff)
             ]))
@@ -512,6 +471,10 @@ class MemoryAsContextTransformer(Module):

         self.to_logits = LinearNoBias(dim, num_tokens)

+        # whether to gate the attention output with the retrieved memories
+
+        self.gate_attn_output = neural_mem_gate_attn_output
+
         # auxiliary loss on kv recon

         self.has_aux_kv_recon_loss = aux_kv_recon_loss_weight > 0.
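With the wrapper gone, the transformer records the gating choice itself and builds each layer as a (memory, attention, feedforward) triple of hyper-connection branches, with None in the memory slot for layers that have no neural memory. A hypothetical usage sketch of the gated configuration; the constructor arguments are taken from this diff and the project README and may not match this exact version:

import torch
from titans_pytorch import MemoryAsContextTransformer

transformer = MemoryAsContextTransformer(
    num_tokens = 256,
    dim = 256,
    depth = 2,
    segment_len = 128,
    num_persist_mem_tokens = 4,
    num_longterm_mem_tokens = 16,
    neural_memory_layers = (1, 2),
    neural_mem_gate_attn_output = True    # gate attention output with the retrieved memories
)

ids = torch.randint(0, 256, (1, 1023))
loss = transformer(ids, return_loss = True)
loss.backward()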
@@ -652,19 +615,34 @@ class MemoryAsContextTransformer(Module):

         x = self.expand_streams(x)

-        for attn, ff in self.layers:
+        for mem, attn, ff in self.layers:
+
+            retrieved = None
+            attn_out_gates = None
+
+            if exists(mem):
+                retrieved, mem_kv_aux_loss = mem(x, return_aux_kv_loss = True)
+                kv_recon_losses = kv_recon_losses + mem_kv_aux_loss

-            (x, values), maybe_mem_kv_aux_loss = attn(
+                if self.gate_attn_output:
+                    attn_out_gates = retrieved.sigmoid()
+                else:
+                    seq = retrieved
+
+            # attention
+
+            x, values = attn(
                 x,
                 value_residual = value_residual,
                 disable_flex_attn = disable_flex_attn,
-                flex_attn_fn = flex_attn_fn
+                flex_attn_fn = flex_attn_fn,
+                output_gating = attn_out_gates
             )

-            kv_recon_losses = kv_recon_losses + maybe_mem_kv_aux_loss
-
             value_residual = default(value_residual, values)

+            # feedforward
+
             x = ff(x)

         x = self.reduce_streams(x)