titans-pytorch 0.1.7__tar.gz → 0.1.9__tar.gz

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.1.7
+Version: 0.1.9
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -1,6 +1,6 @@
 [project]
 name = "titans-pytorch"
-version = "0.1.7"
+version = "0.1.9"
 description = "Titans"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -53,6 +53,19 @@ def test_titans_attn_memory():
 
     assert seq.shape == retrieved.shape
 
+def test_retrieve_store_diff_seq():
+    mem = NeuralMemory(
+        dim = 384,
+        chunk_size = (64, 32),
+    )
+
+    retrieve_seq = torch.randn(2, 64 * 64, 384)
+    store_seq = torch.randn(2, 64 * 32, 384)
+
+    retrieved = mem(retrieve_seq, store_seq = store_seq)
+
+    assert retrieve_seq.shape == retrieved.shape
+
 @pytest.mark.parametrize('seq_len', (1023, 17))
 @pytest.mark.parametrize('num_persist_mem_tokens', (0, 16))
 @pytest.mark.parametrize('num_longterm_mem_tokens', (0, 16))
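
The new test above exercises the headline change in 0.1.9: `chunk_size` may now be a `(retrieve, store)` tuple, and a separately chunked `store_seq` can be passed to `NeuralMemory`'s forward call. A minimal sketch of the two calling conventions, assuming the package's public `NeuralMemory` export (the `pair` helper it relies on is defined in a later hunk):

    # minimal sketch, assuming `from titans_pytorch import NeuralMemory`
    import torch
    from titans_pytorch import NeuralMemory

    # a bare int is broadcast to both roles by the new `pair` helper
    mem = NeuralMemory(dim = 384, chunk_size = 64)

    # a tuple sets retrieve and store chunk sizes independently
    mem = NeuralMemory(dim = 384, chunk_size = (64, 32))

    retrieve_seq = torch.randn(2, 64 * 4, 384)   # chunked at 64 for retrieval
    store_seq = torch.randn(2, 32 * 4, 384)      # chunked at 32 for storage

    retrieved = mem(retrieve_seq, store_seq = store_seq)
    assert retrieved.shape == retrieve_seq.shape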
@@ -83,6 +83,14 @@ def identity(t):
 def round_up_multiple(seq, mult):
     return ceil(seq / mult) * mult
 
+def pack_with_inverse(t, pattern):
+    packed, packed_shape = pack(t, pattern)
+
+    def inverse(out, inv_pattern = None):
+        return unpack(out, packed_shape, default(inv_pattern, pattern))
+
+    return packed, inverse
+
 def pad_at_dim(t, pad, dim = -1, value = 0.):
     dims_from_right = (- dim - 1) if dim < 0 else (t.ndim - dim - 1)
     zeros = ((0, 0) * dims_from_right)
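
`pack_with_inverse` closes over the packed shapes so call sites no longer have to carry `mem_ps` around (see the two `MemoryAsContextTransformer` hunks below). A standalone sketch of the round trip, using einops directly; `default(v, d)` is assumed to be the repo's usual none-fallback helper:

    # standalone sketch of the helper, runnable outside the repo
    import torch
    from einops import pack, unpack

    def default(v, d):
        return v if v is not None else d

    def pack_with_inverse(t, pattern):
        packed, packed_shape = pack(t, pattern)

        def inverse(out, inv_pattern = None):
            return unpack(out, packed_shape, default(inv_pattern, pattern))

        return packed, inverse

    x = torch.randn(2, 128, 384)     # token sequence
    mems = torch.randn(2, 16, 384)   # longterm memory tokens

    packed, inverse = pack_with_inverse((x, mems), 'b * d')
    x_out, mems_out = inverse(packed)    # recovers the original shapes

    assert x_out.shape == x.shape and mems_out.shape == mems.shape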
@@ -576,7 +584,7 @@ class MemoryAsContextTransformer(Module):
         x, inverse_segment = pad_and_segment_with_inverse(x, segment_len)
 
         mems = repeat(self.longterm_mems, 'n d -> b n d', b = x.shape[0])
-        x, mem_ps = pack((x, mems), 'b * d')
+        x, inverse_pack_mems = pack_with_inverse((x, mems), 'b * d')
 
         x = inverse_segment(x)
 
@@ -634,7 +642,7 @@ class MemoryAsContextTransformer(Module):
 
         x, inverse_segment = pad_and_segment_with_inverse(x, segment_len + num_longterm_mem_tokens)
 
-        x, _ = unpack(x, mem_ps, 'b * d')
+        x, _ = inverse_pack_mems(x)
 
         x = inverse_segment(x)
 
@@ -44,6 +44,9 @@ def default(v, d):
 def identity(t):
     return t
 
+def pair(v):
+    return (v, v) if not isinstance(v, tuple) else v
+
 def round_down_multiple(seq, mult):
     return seq // mult * mult
 
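`pair` is the usual int-or-tuple normalizer; a quick sketch of its behavior:

    def pair(v):
        return (v, v) if not isinstance(v, tuple) else v

    assert pair(64) == (64, 64)        # a bare int applies to both retrieve and store
    assert pair((64, 32)) == (64, 32)  # a tuple passes through unchanged
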
@@ -290,7 +293,7 @@ class NeuralMemory(Module):
     def __init__(
         self,
         dim,
-        chunk_size = 1,
+        chunk_size: int | tuple[int, int] = 1,
         dim_head = None,
         heads = 1,
         model: Module | None = None,
@@ -313,6 +316,8 @@ class NeuralMemory(Module):
         super().__init__()
         dim_head = default(dim_head, dim)
 
+        self.retrieve_chunk_size, self.store_chunk_size = pair(chunk_size)
+
         # norms
 
         self.retrieve_norm = nn.RMSNorm(dim) if pre_rmsnorm else nn.Identity()
@@ -380,8 +385,14 @@ class NeuralMemory(Module):
         self.empty_memory_embed = nn.Parameter(torch.zeros(dim))
         nn.init.normal_(self.empty_memory_embed, std = 0.02)
 
+        # `chunk_size` refers to chunk size used for storing to memory model weights
+
+        chunk_size = self.store_chunk_size
+
         # whether to use averaging of chunks, or attention pooling
 
+        assert not (attn_pool_chunks and chunk_size == 1), '`attn_pool_chunks` cannot be set to True if `chunk_size` is set to 1'
+
         if not attn_pool_chunks:
             chunk_reduce_module = Reduce('b (n c) ... -> b n ...', 'mean', c = chunk_size)
         else:
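
The new assertion makes an incompatible configuration fail fast: attention pooling over chunks is meaningless when each store chunk is a single token. A hypothetical repro, assuming `attn_pool_chunks` is the constructor flag referenced in the hunk above:

    # hypothetical repro of the new guard
    import pytest
    from titans_pytorch import NeuralMemory

    with pytest.raises(AssertionError):
        NeuralMemory(dim = 384, chunk_size = 1, attn_pool_chunks = True)

    # fine: store chunks of 32 tokens can be attention-pooled
    mem = NeuralMemory(dim = 384, chunk_size = (64, 32), attn_pool_chunks = True)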
@@ -451,11 +462,11 @@ class NeuralMemory(Module):
         past_state: tuple[dict[str, Tensor], dict[str, Tensor]],
         return_aux_kv_loss = False
     ):
-        seq_len = seq.shape[-2]
+        seq_len, chunk_size = seq.shape[-2], self.store_chunk_size
 
         # handle edge case
 
-        if seq_len < self.chunk_size:
+        if seq_len < chunk_size:
             past_weight, _ = past_state
             return TensorDict(past_weight).clone().zero_(), self.zero
 
@@ -464,8 +475,7 @@ class NeuralMemory(Module):
         # curtail sequence by multiple of the chunk size
         # only a complete chunk of the sequence provides the memory for the next chunk
 
-        seq_len, chunk_size = seq.shape[-2], self.chunk_size
-        round_down_seq_len = round_down_multiple(seq_len, self.chunk_size)
+        round_down_seq_len = round_down_multiple(seq_len, chunk_size)
 
         seq = seq[:, :round_down_seq_len]
 
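Since `chunk_size` is now bound to `self.store_chunk_size` at the top of the storing method, the round-down keeps only complete store chunks. A worked example using the `round_down_multiple` helper shown in an earlier hunk:

    def round_down_multiple(seq, mult):
        return seq // mult * mult

    # with a store chunk size of 32, only complete chunks update the memory
    assert round_down_multiple(1023, 32) == 992   # trailing 31 tokens are dropped
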
@@ -597,12 +607,12 @@ class NeuralMemory(Module):
         seq,
         past_weights: dict[str, Tensor] | None = None,
     ):
-        chunk_size = self.chunk_size
+        chunk_size = self.retrieve_chunk_size
         batch, seq_len = seq.shape[:2]
 
         seq = self.retrieve_norm(seq)
 
-        if seq_len < self.chunk_size:
+        if seq_len < chunk_size:
             return self.init_empty_memory_embed(batch, seq_len)
 
         seq = seq[:, (chunk_size - 1):]
@@ -674,7 +684,7 @@ class NeuralMemory(Module):
     ):
         batch, seq_len = seq.shape[:2]
 
-        if seq_len < self.chunk_size:
+        if seq_len < self.retrieve_chunk_size:
             out = self.init_empty_memory_embed(batch, seq_len)
 
         if not return_aux_kv_loss:
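
On the retrieval side, both hunks above now compare against `self.retrieve_chunk_size`, so sequences shorter than one retrieve chunk fall back to the learned empty-memory embedding. A hedged sketch of that edge case, assuming the per-token output shape implied by `init_empty_memory_embed(batch, seq_len)`:

    # hedged sketch of the short-sequence fallback path
    import torch
    from titans_pytorch import NeuralMemory

    mem = NeuralMemory(dim = 384, chunk_size = (64, 32))

    out = mem(torch.randn(2, 16, 384))   # 16 < 64 retrieve chunk, hits the fallback
    assert out.shape == (2, 16, 384)     # per-token empty-memory embeddings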