titans-pytorch 0.0.58__py3-none-any.whl → 0.0.62__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
titans_pytorch/__init__.py CHANGED
@@ -1,7 +1,8 @@
 from titans_pytorch.titans import (
     NeuralMemory,
     MemoryMLP,
-    MemoryAttention
+    MemoryAttention,
+    FactorizedMemoryMLP
 )
 
 from titans_pytorch.mac_transformer import (
titans_pytorch/mac_transformer.py CHANGED
@@ -24,8 +24,8 @@ def create_mac_block_mask(seq_len, window_size, persist_mem_len):
 
     def create_mac_mask(b, h, q_idx, kv_idx):
         is_persist_mem = kv_idx < persist_mem_len
-        causal_mask = q_idx >= (kv_idx - is_persist_mem)
-        block_diagonal = (q_idx // window_size) == ((kv_idx - is_persist_mem) // window_size)
+        causal_mask = q_idx >= (kv_idx - persist_mem_len)
+        block_diagonal = (q_idx // window_size) == ((kv_idx - persist_mem_len) // window_size)
         return is_persist_mem | (~is_persist_mem & (causal_mask & block_diagonal))
 
     block_mask = create_block_mask(create_mac_mask, B = None, H = None, Q_LEN = seq_len, KV_LEN = seq_len + persist_mem_len, _compile = True)
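Note on the hunk above: the old mask subtracted the boolean `is_persist_mem` (0 or 1) from `kv_idx`, so non-persistent keys were only shifted back by one position; the fix shifts them by the full `persist_mem_len` so they line up with the query positions before the causal and block-diagonal tests. A minimal scalar sketch of the corrected predicate (names and the worked numbers are illustrative, not from the package):

    def mac_mask(q_idx, kv_idx, window_size, persist_mem_len):
        # keys in [0, persist_mem_len) are persistent memory and always visible
        is_persist_mem = kv_idx < persist_mem_len
        kv_seq_idx = kv_idx - persist_mem_len          # position within the real sequence
        causal = q_idx >= kv_seq_idx
        block_diagonal = (q_idx // window_size) == (kv_seq_idx // window_size)
        return is_persist_mem or (causal and block_diagonal)

    # e.g. window_size = 2, persist_mem_len = 1: query 0 sees kv 0 (persistent)
    # and kv 1 (its own token), but not kv 2 (a later token)
    assert mac_mask(0, 0, 2, 1) and mac_mask(0, 1, 2, 1) and not mac_mask(0, 2, 2, 1)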
@@ -489,7 +489,7 @@ class MemoryAsContextTransformer(Module):
         flex_attn_fn = None
 
         if use_flex_attn:
-            block_mask = create_mac_block_mask(seq_len_with_mem, self.segment_len, self.num_persist_mem_tokens)
+            block_mask = create_mac_block_mask(seq_len_with_mem, segment_len + num_longterm_mem_tokens, self.num_persist_mem_tokens)
             flex_attn_fn = partial(flex_attention, block_mask = block_mask)
 
         # value residual
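The hunk above widens the block mask's window: since the MAC transformer interleaves long-term memory tokens with each segment, each block of the flex-attention mask has to span the segment plus those tokens, not just `segment_len`. Illustrative numbers only (not taken from the diff):

    segment_len = 128                                      # tokens per segment (assumed)
    num_longterm_mem_tokens = 4                            # memory tokens interleaved per segment (assumed)
    window_size = segment_len + num_longterm_mem_tokens    # mask blocks are now 132 wide, not 128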
titans_pytorch/titans.py CHANGED
@@ -6,7 +6,7 @@ from functools import partial
 import torch
 from torch import nn, Tensor
 import torch.nn.functional as F
-from torch.nn import Linear, Module
+from torch.nn import Linear, Module, Parameter, ParameterList
 from torch.func import functional_call, vmap, grad
 
 from tensordict import TensorDict
@@ -88,7 +88,7 @@ class MultiheadRMSNorm(Module):
     def __init__(self, dim, heads):
         super().__init__()
         self.rmsnorm = nn.RMSNorm(dim, elementwise_affine = False)
-        self.gamma = nn.Parameter(torch.zeros(heads, 1, dim))
+        self.gamma = Parameter(torch.zeros(heads, 1, dim))
 
     def forward(self, x):
         return self.rmsnorm(x) * (self.gamma + 1.)
@@ -102,7 +102,10 @@ class MemoryMLP(Module):
         depth
     ):
         super().__init__()
-        self.weights = nn.ParameterList([nn.Parameter(torch.randn(dim, dim)) for _ in range(depth)])
+        self.weights = ParameterList([Parameter(torch.randn(dim, dim)) for _ in range(depth)])
+
+        for weight in self.weights:
+            nn.init.xavier_uniform_(weight)
 
     def forward(
         self,
@@ -118,13 +121,50 @@ class MemoryMLP(Module):
 
         return x
 
+# memory mlp with factorized weights
+# so can tradeoff capacity for smaller chunk sizes
+
+class FactorizedMemoryMLP(Module):
+    def __init__(
+        self,
+        dim,
+        depth,
+        k = 32
+    ):
+        super().__init__()
+        self.weights = ParameterList([
+            ParameterList([
+                Parameter(torch.randn(dim, k)),
+                Parameter(torch.randn(k, dim)),
+            ]) for _ in range(depth)
+        ])
+
+        for weight1, weight2 in self.weights:
+            nn.init.xavier_uniform_(weight1)
+            nn.init.xavier_uniform_(weight2)
+
+    def forward(
+        self,
+        x
+    ):
+        for ind, (weight1, weight2) in enumerate(self.weights):
+            is_first = ind == 0
+
+            if not is_first:
+                x = F.silu(x)
+
+            x = x @ weight1 @ weight2
+
+        return x
+
 # improvised attention as memory module
 
 class MemoryAttention(Module):
     def __init__(
         self,
         dim,
-        scale = 8.
+        scale = 8.,
+        expansion_factor = 2.
     ):
         super().__init__()
         self.scale = scale
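The new FactorizedMemoryMLP added above swaps each dim × dim weight of MemoryMLP for a rank-k pair (dim × k and k × dim), trading memory-model capacity for fewer parameters per layer, as the comment in the hunk notes. A minimal usage sketch against the export added in __init__.py (the dim, depth and batch shapes are illustrative):

    import torch
    from titans_pytorch import FactorizedMemoryMLP

    # each layer stores a (dim x k) and a (k x dim) factor instead of a full dim x dim matrix
    memory_model = FactorizedMemoryMLP(dim = 384, depth = 2, k = 32)

    x = torch.randn(2, 64, 384)
    out = memory_model(x)        # shape preserved: (2, 64, 384)

    # per-layer parameters: 2 * 384 * 32 = 24_576 vs 384 * 384 = 147_456 for MemoryMLP

Presumably it is meant to be handed to NeuralMemory the same way MemoryMLP and MemoryAttention are, though that wiring is not shown in this diff.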
@@ -133,10 +173,13 @@ class MemoryAttention(Module):
             nn.Parameter(torch.randn(dim, dim)), # queries
             nn.Parameter(torch.randn(dim, dim)), # keys
             nn.Parameter(torch.randn(dim, dim)), # values
-            nn.Parameter(torch.randn(dim, dim * 2)), # ff w1
-            nn.Parameter(torch.randn(dim * 2, dim)), # ff w2
+            nn.Parameter(torch.randn(dim, dim * expansion_factor)), # ff w1
+            nn.Parameter(torch.randn(dim * expansion_factor, dim)), # ff w2
         ])
 
+        for weight in self.weights:
+            nn.init.xavier_uniform_(weight)
+
     def forward(self, x):
         wq, wk, wv, ffw1, ffw2 = self.weights
 
@@ -289,6 +332,8 @@ class NeuralMemory(Module):
 
         self.use_accelerated_scan = use_accelerated_scan
 
+        self.register_buffer('zero', torch.tensor(0.), persistent = False)
+
     def init_weights_and_momentum(self):
         params = TensorDict(dict(self.memory_model.named_parameters()))
 
@@ -306,6 +351,13 @@ class NeuralMemory(Module):
         past_state: tuple[dict[str, Tensor], dict[str, Tensor]],
         return_aux_kv_loss = False
     ):
+        seq_len = seq.shape[-2]
+
+        # handle edge case
+
+        if seq_len < self.chunk_size:
+            past_weight, _ = past_state
+            return TensorDict(past_weight).clone().zero_(), self.zero
 
         seq = self.store_norm(seq)
 
@@ -425,12 +477,10 @@ class NeuralMemory(Module):
 
         last_update = updates.apply(lambda t: t[:, -1])
 
-        next_state = (curr_weights + last_update, next_momentum)
-
         if not return_aux_kv_loss:
-            return updates, next_state
+            return updates
 
-        return updates, next_state, aux_kv_recon_loss.mean()
+        return updates, aux_kv_recon_loss.mean()
 
     def retrieve_memories(
         self,
@@ -442,7 +492,8 @@ class NeuralMemory(Module):
 
         seq = self.retrieve_norm(seq)
 
-        assert seq_len >= chunk_size
+        if seq_len < self.chunk_size:
+            return self.init_empty_memory_embed(batch, seq_len)
 
         seq = seq[:, (chunk_size - 1):]
         curtailed_seq_len = seq.shape[-2]
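The two edge-case hunks above drop the old seq_len >= chunk_size requirement: a sequence shorter than one chunk now yields zeroed weight updates and a zero auxiliary loss from store_memories, and the learned empty-memory embedding from retrieve_memories, instead of an assertion error. A hedged sketch of the resulting behaviour, assuming the NeuralMemory constructor and call signature shown in the project README:

    import torch
    from titans_pytorch import NeuralMemory

    mem = NeuralMemory(dim = 384, chunk_size = 64)   # hyperparameters illustrative

    short_seq = torch.randn(1, 10, 384)              # fewer tokens than one chunk
    retrieved = mem(short_seq)                       # 0.0.58 tripped the assert here;
                                                     # 0.0.62 returns the empty-memory embedding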
@@ -524,10 +575,11 @@ class NeuralMemory(Module):
 
         store_seq = default(store_seq, seq)
 
-        updates, next_memories, aux_kv_recon_loss = self.store_memories(store_seq, past_state, return_aux_kv_loss = True)
+        updates, aux_kv_recon_loss = self.store_memories(store_seq, past_state, return_aux_kv_loss = True)
 
         past_weights, _ = past_state
 
+
         retrieved = self.retrieve_memories(seq, past_weights + updates)
 
         if not return_aux_kv_loss:
{titans_pytorch-0.0.58.dist-info → titans_pytorch-0.0.62.dist-info}/METADATA RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.0.58
+Version: 0.0.62
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
titans_pytorch-0.0.62.dist-info/RECORD ADDED
@@ -0,0 +1,8 @@
+titans_pytorch/__init__.py,sha256=I-4oF1JPEmUvhLszEocM5cDgHYNFcNp0Q9nbDYSPFqU,195
+titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
+titans_pytorch/mac_transformer.py,sha256=R9Xu-BjqAe9ZY60IGk4aNXBx_L8THsjJ4QrkbTnRNHo,15346
+titans_pytorch/titans.py,sha256=95J6UL44lOrdZSXdm7p36xC9tDeSmRBwdjig9T82PzI,17452
+titans_pytorch-0.0.62.dist-info/METADATA,sha256=08Blaa9Ehyv09rSA9uWguxbhKpbrd7C53Ya13E1VbpU,4457
+titans_pytorch-0.0.62.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+titans_pytorch-0.0.62.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+titans_pytorch-0.0.62.dist-info/RECORD,,
titans_pytorch-0.0.58.dist-info/RECORD DELETED
@@ -1,8 +0,0 @@
-titans_pytorch/__init__.py,sha256=Q0MQA3RS8vqzs-KzSGZkvLR7afQ6ZW9uMOq1MeNuFoY,170
-titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
-titans_pytorch/mac_transformer.py,sha256=kk8s8Q2WmbJxCVi8PcqSUyJBc8-CDAHrVjt6M0d_kFs,15323
-titans_pytorch/titans.py,sha256=ZKm-LnVKh1Cxs2tSxr4CcY37KroOOmYtTFM2F3Zb8Xg,16122
-titans_pytorch-0.0.58.dist-info/METADATA,sha256=a-Y6MV_89D44HlB7eKpurh-sw5DDiS-pIVei3Uw_uGE,4457
-titans_pytorch-0.0.58.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-titans_pytorch-0.0.58.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-titans_pytorch-0.0.58.dist-info/RECORD,,