titans-pytorch 0.0.58__py3-none-any.whl → 0.0.61__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
titans_pytorch/mac_transformer.py CHANGED
@@ -24,8 +24,8 @@ def create_mac_block_mask(seq_len, window_size, persist_mem_len):
24
24
 
25
25
  def create_mac_mask(b, h, q_idx, kv_idx):
26
26
  is_persist_mem = kv_idx < persist_mem_len
27
- causal_mask = q_idx >= (kv_idx - is_persist_mem)
28
- block_diagonal = (q_idx // window_size) == ((kv_idx - is_persist_mem) // window_size)
27
+ causal_mask = q_idx >= (kv_idx - persist_mem_len)
28
+ block_diagonal = (q_idx // window_size) == ((kv_idx - persist_mem_len) // window_size)
29
29
  return is_persist_mem | (~is_persist_mem & (causal_mask & block_diagonal))
30
30
 
31
31
  block_mask = create_block_mask(create_mac_mask, B = None, H = None, Q_LEN = seq_len, KV_LEN = seq_len + persist_mem_len, _compile = True)
@@ -489,7 +489,7 @@ class MemoryAsContextTransformer(Module):
489
489
  flex_attn_fn = None
490
490
 
491
491
  if use_flex_attn:
492
- block_mask = create_mac_block_mask(seq_len_with_mem, self.segment_len, self.num_persist_mem_tokens)
492
+ block_mask = create_mac_block_mask(seq_len_with_mem, segment_len + num_longterm_mem_tokens, self.num_persist_mem_tokens)
493
493
  flex_attn_fn = partial(flex_attention, block_mask = block_mask)
494
494
 
495
495
  # value residual
titans_pytorch/titans.py CHANGED
@@ -289,6 +289,8 @@ class NeuralMemory(Module):
289
289
 
290
290
  self.use_accelerated_scan = use_accelerated_scan
291
291
 
292
+ self.register_buffer('zero', torch.tensor(0.), persistent = False)
293
+
292
294
  def init_weights_and_momentum(self):
293
295
  params = TensorDict(dict(self.memory_model.named_parameters()))
294
296
 
@@ -306,6 +308,13 @@ class NeuralMemory(Module):
306
308
  past_state: tuple[dict[str, Tensor], dict[str, Tensor]],
307
309
  return_aux_kv_loss = False
308
310
  ):
311
+ seq_len = seq.shape[-2]
312
+
313
+ # handle edge case
314
+
315
+ if seq_len < self.chunk_size:
316
+ past_weight, _ = past_state
317
+ return TensorDict(past_weight).clone().zero_(), self.zero
309
318
 
310
319
  seq = self.store_norm(seq)
311
320
 
@@ -425,12 +434,10 @@ class NeuralMemory(Module):
425
434
 
426
435
  last_update = updates.apply(lambda t: t[:, -1])
427
436
 
428
- next_state = (curr_weights + last_update, next_momentum)
429
-
430
437
  if not return_aux_kv_loss:
431
- return updates, next_state
438
+ return updates
432
439
 
433
- return updates, next_state, aux_kv_recon_loss.mean()
440
+ return updates, aux_kv_recon_loss.mean()
434
441
 
435
442
  def retrieve_memories(
436
443
  self,
@@ -442,7 +449,8 @@ class NeuralMemory(Module):
442
449
 
443
450
  seq = self.retrieve_norm(seq)
444
451
 
445
- assert seq_len >= chunk_size
452
+ if seq_len < self.chunk_size:
453
+ return self.init_empty_memory_embed(batch, seq_len)
446
454
 
447
455
  seq = seq[:, (chunk_size - 1):]
448
456
  curtailed_seq_len = seq.shape[-2]
@@ -524,7 +532,7 @@ class NeuralMemory(Module):
524
532
 
525
533
  store_seq = default(store_seq, seq)
526
534
 
527
- updates, next_memories, aux_kv_recon_loss = self.store_memories(store_seq, past_state, return_aux_kv_loss = True)
535
+ updates, aux_kv_recon_loss = self.store_memories(store_seq, past_state, return_aux_kv_loss = True)
528
536
 
529
537
  past_weights, _ = past_state
530
538
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: titans-pytorch
3
- Version: 0.0.58
3
+ Version: 0.0.61
4
4
  Summary: Titans
5
5
  Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
6
6
  Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -0,0 +1,8 @@
1
+ titans_pytorch/__init__.py,sha256=Q0MQA3RS8vqzs-KzSGZkvLR7afQ6ZW9uMOq1MeNuFoY,170
2
+ titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
3
+ titans_pytorch/mac_transformer.py,sha256=R9Xu-BjqAe9ZY60IGk4aNXBx_L8THsjJ4QrkbTnRNHo,15346
4
+ titans_pytorch/titans.py,sha256=5wuAoDULbgXTM8Nbq8bXrW3Fd2nsn22kpERRfJOwZiU,16367
5
+ titans_pytorch-0.0.61.dist-info/METADATA,sha256=Cfhqnse_9nnFNqVGo9p_kxO_LVawwv4uuZOx4anqhf0,4457
6
+ titans_pytorch-0.0.61.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
7
+ titans_pytorch-0.0.61.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
8
+ titans_pytorch-0.0.61.dist-info/RECORD,,
@@ -1,8 +0,0 @@
1
- titans_pytorch/__init__.py,sha256=Q0MQA3RS8vqzs-KzSGZkvLR7afQ6ZW9uMOq1MeNuFoY,170
2
- titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
3
- titans_pytorch/mac_transformer.py,sha256=kk8s8Q2WmbJxCVi8PcqSUyJBc8-CDAHrVjt6M0d_kFs,15323
4
- titans_pytorch/titans.py,sha256=ZKm-LnVKh1Cxs2tSxr4CcY37KroOOmYtTFM2F3Zb8Xg,16122
5
- titans_pytorch-0.0.58.dist-info/METADATA,sha256=a-Y6MV_89D44HlB7eKpurh-sw5DDiS-pIVei3Uw_uGE,4457
6
- titans_pytorch-0.0.58.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
7
- titans_pytorch-0.0.58.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
8
- titans_pytorch-0.0.58.dist-info/RECORD,,