titans-pytorch 0.1.7.tar.gz → 0.1.8.tar.gz

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.1.7
+Version: 0.1.8
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -1,6 +1,6 @@
 [project]
 name = "titans-pytorch"
-version = "0.1.7"
+version = "0.1.8"
 description = "Titans"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -53,6 +53,19 @@ def test_titans_attn_memory():

     assert seq.shape == retrieved.shape

+def test_retrieve_store_diff_seq():
+    mem = NeuralMemory(
+        dim = 384,
+        chunk_size = (64, 32),
+    )
+
+    retrieve_seq = torch.randn(2, 64 * 64, 384)
+    store_seq = torch.randn(2, 64 * 32, 384)
+
+    retrieved = mem(retrieve_seq, store_seq = store_seq)
+
+    assert retrieve_seq.shape == retrieved.shape
+
 @pytest.mark.parametrize('seq_len', (1023, 17))
 @pytest.mark.parametrize('num_persist_mem_tokens', (0, 16))
 @pytest.mark.parametrize('num_longterm_mem_tokens', (0, 16))
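Note on the new test: chunk_size = (64, 32) means the retrieval path chunks at 64 and the storage path at 32. The two sequence lengths are picked so each is a whole multiple of its own chunk size, so neither path rounds anything off:

    # shape arithmetic behind the test above (illustrative, not part of the diff)
    retrieve_len = 64 * 64  # 4096 tokens -> 64 complete retrieve chunks of 64
    store_len = 64 * 32     # 2048 tokens -> 64 complete store chunks of 32
    assert retrieve_len % 64 == 0 and store_len % 32 == 0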
@@ -44,6 +44,9 @@ def default(v, d):
 def identity(t):
     return t

+def pair(v):
+    return (v, v) if not isinstance(v, tuple) else v
+
 def round_down_multiple(seq, mult):
     return seq // mult * mult

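Note: the new pair helper normalizes the chunk_size argument; its behavior can be read directly off the one-liner above:

    # an int is duplicated into (retrieve, store); a tuple passes through untouched
    assert pair(64) == (64, 64)
    assert pair((64, 32)) == (64, 32)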
@@ -290,7 +293,7 @@ class NeuralMemory(Module):
     def __init__(
         self,
         dim,
-        chunk_size = 1,
+        chunk_size: int | tuple[int, int] = 1,
         dim_head = None,
         heads = 1,
         model: Module | None = None,
@@ -313,6 +316,8 @@ class NeuralMemory(Module):
         super().__init__()
         dim_head = default(dim_head, dim)

+        self.retrieve_chunk_size, self.store_chunk_size = pair(chunk_size)
+
         # norms

         self.retrieve_norm = nn.RMSNorm(dim) if pre_rmsnorm else nn.Identity()
@@ -380,6 +385,10 @@ class NeuralMemory(Module):
         self.empty_memory_embed = nn.Parameter(torch.zeros(dim))
         nn.init.normal_(self.empty_memory_embed, std = 0.02)

+        # `chunk_size` refers to chunk size used for storing to memory model weights
+
+        chunk_size = self.store_chunk_size
+
         # whether to use averaging of chunks, or attention pooling

         if not attn_pool_chunks:
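Note: rebinding the local chunk_size to self.store_chunk_size means the chunk-pooling setup that follows (averaging vs attention pooling over chunks) is sized for the store path, i.e. the path that writes into the memory model weights, exactly as the new comment states.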
@@ -451,11 +460,11 @@ class NeuralMemory(Module):
         past_state: tuple[dict[str, Tensor], dict[str, Tensor]],
         return_aux_kv_loss = False
     ):
-        seq_len = seq.shape[-2]
+        seq_len, chunk_size = seq.shape[-2], self.store_chunk_size

         # handle edge case

-        if seq_len < self.chunk_size:
+        if seq_len < chunk_size:
             past_weight, _ = past_state
             return TensorDict(past_weight).clone().zero_(), self.zero

@@ -464,8 +473,7 @@ class NeuralMemory(Module):
         # curtail sequence by multiple of the chunk size
         # only a complete chunk of the sequence provides the memory for the next chunk

-        seq_len, chunk_size = seq.shape[-2], self.chunk_size
-        round_down_seq_len = round_down_multiple(seq_len, self.chunk_size)
+        round_down_seq_len = round_down_multiple(seq_len, chunk_size)

         seq = seq[:, :round_down_seq_len]

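Note: only complete chunks are stored, so the store sequence is rounded down to a multiple of the store chunk size. A small worked example (values illustrative, not from the diff):

    # with store chunk size 32, a length-100 sequence yields 3 complete chunks
    seq_len, chunk_size = 100, 32
    round_down_seq_len = round_down_multiple(seq_len, chunk_size)  # 100 // 32 * 32 == 96
    # seq[:, :96] is stored; the trailing 4 tokens do not update memory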
@@ -597,12 +605,12 @@ class NeuralMemory(Module):
         seq,
         past_weights: dict[str, Tensor] | None = None,
     ):
-        chunk_size = self.chunk_size
+        chunk_size = self.retrieve_chunk_size
         batch, seq_len = seq.shape[:2]

         seq = self.retrieve_norm(seq)

-        if seq_len < self.chunk_size:
+        if seq_len < chunk_size:
             return self.init_empty_memory_embed(batch, seq_len)

         seq = seq[:, (chunk_size - 1):]
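Note: on the retrieve side, seq[:, (chunk_size - 1):] drops the first chunk_size - 1 positions, which precede any complete chunk; presumably those positions fall back to the empty-memory embedding, the same fallback used in the short-sequence early return above.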
@@ -674,7 +682,7 @@ class NeuralMemory(Module):
     ):
         batch, seq_len = seq.shape[:2]

-        if seq_len < self.chunk_size:
+        if seq_len < self.retrieve_chunk_size:
             out = self.init_empty_memory_embed(batch, seq_len)

             if not return_aux_kv_loss: