titans-pytorch 0.1.6__py3-none-any.whl → 0.1.8__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry, and is provided for informational purposes only.
titans_pytorch/titans.py CHANGED
@@ -44,6 +44,9 @@ def default(v, d):
 def identity(t):
     return t

+def pair(v):
+    return (v, v) if not isinstance(v, tuple) else v
+
 def round_down_multiple(seq, mult):
     return seq // mult * mult

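The new `pair` helper normalizes a scalar-or-tuple argument, while the pre-existing `round_down_multiple` truncates a length to a whole number of chunks. A quick sketch of their behavior, with illustrative values:

    pair(64)                        # (64, 64)
    pair((64, 16))                  # (64, 16)
    round_down_multiple(1000, 64)   # 960, i.e. 15 complete chunks of 64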
@@ -161,14 +164,16 @@ class GatedResidualMemoryMLP(Module):
     def __init__(
         self,
         dim,
-        depth
+        depth,
+        expansion_factor = 2.
     ):
         super().__init__()
-        self.depth = depth
+        dim_hidden = int(dim * expansion_factor)

         self.weights = ParameterList([
             ParameterList([
-                Parameter(torch.randn(dim, dim)),
+                Parameter(torch.randn(dim, dim_hidden)),
+                Parameter(torch.randn(dim_hidden, dim)),
                 Parameter(torch.randn(dim * 2, dim)),
             ]) for _ in range(depth)
         ])
@@ -182,16 +187,17 @@ class GatedResidualMemoryMLP(Module):
         self,
         x
     ):
-        for weight, to_gates in self.weights:
+        for weight1, weight2, to_gates in self.weights:
             res = x

-            x = x @ weight
-            x = F.silu(x)
+            hidden = x @ weight1
+            hidden = F.silu(hidden)
+            branch_out = hidden @ weight2

             # gated residual

-            gates = cat((x, res), dim = -1) @ to_gates
-            x = res.lerp(x, gates.sigmoid())
+            gates = cat((branch_out, res), dim = -1) @ to_gates
+            x = res.lerp(branch_out, gates.sigmoid())

         return x @ self.final_proj

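With these two hunks, each layer of `GatedResidualMemoryMLP` goes from a single square `dim × dim` weight to a SiLU-activated expansion of width `int(dim * expansion_factor)` followed by a projection back to `dim`, and the gated residual is now computed against the branch output. A minimal standalone restatement of one layer step, assuming only `torch`:

    import torch
    import torch.nn.functional as F

    def gated_residual_step(x, weight1, weight2, to_gates):
        # expand to the hidden width and activate: (..., dim) -> (..., dim_hidden)
        hidden = F.silu(x @ weight1)

        # project back down: (..., dim_hidden) -> (..., dim)
        branch_out = hidden @ weight2

        # gate computed from branch output and residual: (..., dim * 2) -> (..., dim)
        gates = torch.cat((branch_out, x), dim = -1) @ to_gates

        # sigmoid-gated interpolation between residual and branch
        return x.lerp(branch_out, gates.sigmoid())

With the default `expansion_factor = 2.` and, say, `dim = 512`, `weight1` is `512 × 1024` and `weight2` is `1024 × 512`, roughly doubling each layer's parameter count versus the old single `512 × 512` weight (the `1024 × 512` gate matrix is unchanged).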
@@ -287,7 +293,7 @@ class NeuralMemory(Module):
     def __init__(
         self,
         dim,
-        chunk_size = 1,
+        chunk_size: int | tuple[int, int] = 1,
         dim_head = None,
         heads = 1,
         model: Module | None = None,
@@ -310,6 +316,8 @@ class NeuralMemory(Module):
         super().__init__()
         dim_head = default(dim_head, dim)

+        self.retrieve_chunk_size, self.store_chunk_size = pair(chunk_size)
+
         # norms

         self.retrieve_norm = nn.RMSNorm(dim) if pre_rmsnorm else nn.Identity()
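Because `pair` duplicates a scalar, passing a plain `int` preserves the old coupled behavior, while a tuple decouples retrieval from storage. Hypothetical constructions, with illustrative dimensions:

    NeuralMemory(dim = 384, chunk_size = 64)        # retrieve and store both chunked at 64, as before
    NeuralMemory(dim = 384, chunk_size = (64, 16))  # retrieve chunk size 64, store chunk size 16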
@@ -377,6 +385,10 @@ class NeuralMemory(Module):
         self.empty_memory_embed = nn.Parameter(torch.zeros(dim))
         nn.init.normal_(self.empty_memory_embed, std = 0.02)

+        # `chunk_size` refers to chunk size used for storing to memory model weights
+
+        chunk_size = self.store_chunk_size
+
         # whether to use averaging of chunks, or attention pooling

         if not attn_pool_chunks:
@@ -448,11 +460,11 @@ class NeuralMemory(Module):
         past_state: tuple[dict[str, Tensor], dict[str, Tensor]],
         return_aux_kv_loss = False
     ):
-        seq_len = seq.shape[-2]
+        seq_len, chunk_size = seq.shape[-2], self.store_chunk_size

         # handle edge case

-        if seq_len < self.chunk_size:
+        if seq_len < chunk_size:
             past_weight, _ = past_state
             return TensorDict(past_weight).clone().zero_(), self.zero

@@ -461,8 +473,7 @@ class NeuralMemory(Module):
         # curtail sequence by multiple of the chunk size
         # only a complete chunk of the sequence provides the memory for the next chunk

-        seq_len, chunk_size = seq.shape[-2], self.chunk_size
-        round_down_seq_len = round_down_multiple(seq_len, self.chunk_size)
+        round_down_seq_len = round_down_multiple(seq_len, chunk_size)

         seq = seq[:, :round_down_seq_len]

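Only whole store-side chunks contribute a weight update; any remainder is dropped for the current pass. With illustrative numbers:

    # store_chunk_size = 16, seq_len = 100
    round_down_multiple(100, 16)   # 96 -> the trailing 4 tokens produce no update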
@@ -594,12 +605,12 @@ class NeuralMemory(Module):
         seq,
         past_weights: dict[str, Tensor] | None = None,
     ):
-        chunk_size = self.chunk_size
+        chunk_size = self.retrieve_chunk_size
         batch, seq_len = seq.shape[:2]

         seq = self.retrieve_norm(seq)

-        if seq_len < self.chunk_size:
+        if seq_len < chunk_size:
             return self.init_empty_memory_embed(batch, seq_len)

         seq = seq[:, (chunk_size - 1):]
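The `(chunk_size - 1)` offset appears to keep retrieval causal: the first `chunk_size - 1` positions precede any complete chunk and are instead covered by the `seq_len < chunk_size` early return and its empty memory embed, so every remaining position queries only weights written by chunks that completed before it, matching the "only a complete chunk of the sequence provides the memory for the next chunk" comment in the storage path.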
@@ -671,7 +682,7 @@ class NeuralMemory(Module):
     ):
         batch, seq_len = seq.shape[:2]

-        if seq_len < self.chunk_size:
+        if seq_len < self.retrieve_chunk_size:
             out = self.init_empty_memory_embed(batch, seq_len)

             if not return_aux_kv_loss:
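Taken together, a minimal usage sketch of the decoupled chunk sizes. Values are illustrative, and the call pattern assumes the package's existing `NeuralMemory` forward, which returns the retrieved sequence when `return_aux_kv_loss` is left at its default:

    import torch
    from titans_pytorch import NeuralMemory

    mem = NeuralMemory(
        dim = 384,
        chunk_size = (64, 16)   # retrieve chunk size 64, store chunk size 16
    )

    seq = torch.randn(2, 1024, 384)
    retrieved = mem(seq)             # same shape as seq
    assert retrieved.shape == seq.shape

    short = torch.randn(2, 8, 384)   # shorter than the retrieve chunk size
    out = mem(short)                 # falls back to the empty memory embed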
titans_pytorch-{0.1.6 → 0.1.8}.dist-info/METADATA RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.1.6
+Version: 0.1.8
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
titans_pytorch-{0.1.6 → 0.1.8}.dist-info/RECORD RENAMED
@@ -0,0 +1,8 @@
+titans_pytorch/__init__.py,sha256=u0tta_KqhOdfzCEDWT9P4_jejJEK2q1XxhsEzB5MnQU,223
+titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
+titans_pytorch/mac_transformer.py,sha256=YowJzQC2p3fWgzFCe9jXrw1b3wWPKN-dcLs5sX-3r8Y,19123
+titans_pytorch/titans.py,sha256=qRUw-Lad_dkMqV7ASMNoGLgxYwGD-maAadetAd_qmc8,21031
+titans_pytorch-0.1.8.dist-info/METADATA,sha256=0-m6h7GERineU8N9_2cW6nCuXs96twFEwVYkHVuuuLM,4747
+titans_pytorch-0.1.8.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+titans_pytorch-0.1.8.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+titans_pytorch-0.1.8.dist-info/RECORD,,
@@ -1,8 +0,0 @@
-titans_pytorch/__init__.py,sha256=u0tta_KqhOdfzCEDWT9P4_jejJEK2q1XxhsEzB5MnQU,223
-titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
-titans_pytorch/mac_transformer.py,sha256=YowJzQC2p3fWgzFCe9jXrw1b3wWPKN-dcLs5sX-3r8Y,19123
-titans_pytorch/titans.py,sha256=VMcPcKsoR3G13Um62Aa1HbdwrrV60ljPhP-yF40x90I,20555
-titans_pytorch-0.1.6.dist-info/METADATA,sha256=LJW26WfT9WB-0NfokLLHhcRpWnt76jwkXMt_FSTI3SM,4747
-titans_pytorch-0.1.6.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-titans_pytorch-0.1.6.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-titans_pytorch-0.1.6.dist-info/RECORD,,