titans-pytorch 0.0.57__py3-none-any.whl → 0.0.61__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry, and is provided for informational purposes only.
titans_pytorch/mac_transformer.py CHANGED
@@ -24,8 +24,8 @@ def create_mac_block_mask(seq_len, window_size, persist_mem_len):
 
     def create_mac_mask(b, h, q_idx, kv_idx):
         is_persist_mem = kv_idx < persist_mem_len
-        causal_mask = q_idx >= (kv_idx - is_persist_mem)
-        block_diagonal = (q_idx // window_size) == ((kv_idx - is_persist_mem) // window_size)
+        causal_mask = q_idx >= (kv_idx - persist_mem_len)
+        block_diagonal = (q_idx // window_size) == ((kv_idx - persist_mem_len) // window_size)
         return is_persist_mem | (~is_persist_mem & (causal_mask & block_diagonal))
 
     block_mask = create_block_mask(create_mac_mask, B = None, H = None, Q_LEN = seq_len, KV_LEN = seq_len + persist_mem_len, _compile = True)
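Note on the fix: is_persist_mem is a boolean, so the old expressions subtracted only 0 or 1 from kv_idx, while the key/value axis is actually offset by the full persist_mem_len prefix. A minimal sketch of the corrected predicate as a plain scalar function (not the torch flex-attention mask API; names mirror the hunk):

    def mac_mask(q_idx, kv_idx, window_size, persist_mem_len):
        # persistent memory slots sit at the front of the kv axis and are always visible
        if kv_idx < persist_mem_len:
            return True
        # realign the kv index to a true sequence position before the checks
        kv_pos = kv_idx - persist_mem_len
        causal = q_idx >= kv_pos
        same_window = (q_idx // window_size) == (kv_pos // window_size)
        return causal and same_window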
@@ -292,12 +292,13 @@ class NeuralMemoryGatingWrapper(Module):
         self,
         dim,
         attn: SegmentedAttention,
-        neural_mem: NeuralMemory | None = None
+        neural_mem: NeuralMemory | None = None,
+        gate_attn_output = True
     ):
         super().__init__()
         self.attn = attn
         self.neural_mem = neural_mem
-        self.to_gates = nn.Linear(dim, dim) if exists(neural_mem) else None
+        self.gate_attn_output = gate_attn_output
 
     def forward(
         self,
@@ -313,21 +314,19 @@ class NeuralMemoryGatingWrapper(Module):
 
         # initial retrieve, still should store first, it doesn't make sense not to, unless if all layers share the same neural memory
 
-        retrieved, first_kv_aux_loss = mem(seq, return_aux_kv_loss = True)
+        retrieved, kv_aux_loss = mem(seq, return_aux_kv_loss = True)
 
-        seq = seq + retrieved
+        if not self.gate_attn_output:
+            seq = seq + retrieved
 
         # attention
 
         attn_out, values = self.attn(seq, *args, **kwargs)
 
-        # another retrieve, but this time gate the attention output
+        if self.gate_attn_output:
+            attn_out = attn_out * retrieved.sigmoid()
 
-        retrieved, second_kv_aux_loss = mem(attn_out, return_aux_kv_loss = True)
-
-        attn_out = attn_out * self.to_gates(retrieved).sigmoid()
-
-        return (attn_out, values), (first_kv_aux_loss + second_kv_aux_loss)
+        return (attn_out, values), kv_aux_loss
 
 # MAC transformer
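Net effect of this hunk and the constructor change above: the wrapper retrieves from neural memory once, drops the learned nn.Linear gate and the second retrieval, and uses the single retrieval in one of two ways. A condensed view of the new forward path, reconstructed from the + lines (mem is the wrapped NeuralMemory):

    retrieved, kv_aux_loss = mem(seq, return_aux_kv_loss = True)

    if not self.gate_attn_output:
        seq = seq + retrieved                      # additive: inject memories before attention

    attn_out, values = self.attn(seq, *args, **kwargs)

    if self.gate_attn_output:
        attn_out = attn_out * retrieved.sigmoid()  # multiplicative: gate the attention output

    return (attn_out, values), kv_aux_loss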
 
@@ -340,6 +339,7 @@ class MemoryAsContextTransformer(Module):
         depth,
         segment_len,
         neural_memory_segment_len = None,
+        neural_mem_gate_attn_output = True,
         num_longterm_mem_tokens = 0,
         num_persist_mem_tokens = 0,
         dim_head = 64,
@@ -414,6 +414,7 @@ class MemoryAsContextTransformer(Module):
                 dim,
                 attn = attn,
                 neural_mem = mem,
+                gate_attn_output = neural_mem_gate_attn_output
             )
 
             ff = FeedForward(dim = dim, mult = ff_mult)
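These two hunks thread the new flag from the transformer constructor down to each per-layer NeuralMemoryGatingWrapper. A hypothetical construction showing where it is passed; every argument value below is illustrative, and num_tokens is assumed from the surrounding class rather than shown in this diff:

    from titans_pytorch.mac_transformer import MemoryAsContextTransformer

    transformer = MemoryAsContextTransformer(
        num_tokens = 256,
        dim = 256,
        depth = 2,
        segment_len = 128,
        num_persist_mem_tokens = 4,
        num_longterm_mem_tokens = 16,
        neural_mem_gate_attn_output = False   # add memories to the residual stream instead of gating attention output
    )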
@@ -488,7 +489,7 @@ class MemoryAsContextTransformer(Module):
         flex_attn_fn = None
 
         if use_flex_attn:
-            block_mask = create_mac_block_mask(seq_len_with_mem, self.segment_len, self.num_persist_mem_tokens)
+            block_mask = create_mac_block_mask(seq_len_with_mem, segment_len + num_longterm_mem_tokens, self.num_persist_mem_tokens)
             flex_attn_fn = partial(flex_attention, block_mask = block_mask)
 
         # value residual
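The second argument to create_mac_block_mask is the window size of the block mask. Long-term memory tokens are interspersed into each segment, so the effective window of the memory-augmented sequence is wider than segment_len alone; the call site now accounts for that. With illustrative numbers (assumed, not from this diff):

    segment_len = 128
    num_longterm_mem_tokens = 16
    window_size = segment_len + num_longterm_mem_tokens   # each mask block spans 144 positions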
titans_pytorch/titans.py CHANGED
@@ -289,6 +289,8 @@ class NeuralMemory(Module):
 
         self.use_accelerated_scan = use_accelerated_scan
 
+        self.register_buffer('zero', torch.tensor(0.), persistent = False)
+
     def init_weights_and_momentum(self):
         params = TensorDict(dict(self.memory_model.named_parameters()))
 
@@ -306,6 +308,13 @@ class NeuralMemory(Module):
         past_state: tuple[dict[str, Tensor], dict[str, Tensor]],
         return_aux_kv_loss = False
     ):
+        seq_len = seq.shape[-2]
+
+        # handle edge case
+
+        if seq_len < self.chunk_size:
+            past_weight, _ = past_state
+            return TensorDict(past_weight).clone().zero_(), self.zero
 
         seq = self.store_norm(seq)
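Taken together with the zero buffer registered in the previous hunk (a non-persistent buffer that supplies a device-matched scalar for the aux loss), store_memories now short-circuits on sequences shorter than one chunk, returning updates with the same structure as the past weights but all zeros. A minimal check of that zeroed-update contract, assuming the tensordict package that titans.py already uses:

    import torch
    from tensordict import TensorDict

    past_weight = {'weight': torch.randn(2, 3)}
    zeroed = TensorDict(past_weight).clone().zero_()   # same keys and shapes, values all zero

    assert (zeroed['weight'] == 0).all()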
 
@@ -425,12 +434,10 @@ class NeuralMemory(Module):
 
         last_update = updates.apply(lambda t: t[:, -1])
 
-        next_state = (curr_weights + last_update, next_momentum)
-
         if not return_aux_kv_loss:
-            return updates, next_state
+            return updates
 
-        return updates, next_state, aux_kv_recon_loss.mean()
+        return updates, aux_kv_recon_loss.mean()
 
     def retrieve_memories(
         self,
@@ -442,7 +449,8 @@ class NeuralMemory(Module):
 
         seq = self.retrieve_norm(seq)
 
-        assert seq_len >= chunk_size
+        if seq_len < self.chunk_size:
+            return self.init_empty_memory_embed(batch, seq_len)
 
         seq = seq[:, (chunk_size - 1):]
         curtailed_seq_len = seq.shape[-2]
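With the hard assert replaced by a graceful fallback, retrieval on a sequence shorter than one chunk returns the learned empty-memory embedding instead of raising. A hypothetical smoke test; the constructor arguments follow the project README, and the exact sizes are illustrative:

    import torch
    from titans_pytorch import NeuralMemory

    mem = NeuralMemory(dim = 384, chunk_size = 64)

    short_seq = torch.randn(2, 16, 384)   # 16 tokens, less than chunk_size 64
    retrieved = mem(short_seq)            # previously tripped the assert in retrieve_memories

    assert retrieved.shape == short_seq.shape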
@@ -524,7 +532,7 @@ class NeuralMemory(Module):
 
         store_seq = default(store_seq, seq)
 
-        updates, next_memories, aux_kv_recon_loss = self.store_memories(store_seq, past_state, return_aux_kv_loss = True)
+        updates, aux_kv_recon_loss = self.store_memories(store_seq, past_state, return_aux_kv_loss = True)
         past_weights, _ = past_state
 
titans_pytorch-0.0.57.dist-info/METADATA → titans_pytorch-0.0.61.dist-info/METADATA RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.0.57
+Version: 0.0.61
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
titans_pytorch-0.0.61.dist-info/RECORD ADDED
@@ -0,0 +1,8 @@
+titans_pytorch/__init__.py,sha256=Q0MQA3RS8vqzs-KzSGZkvLR7afQ6ZW9uMOq1MeNuFoY,170
+titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
+titans_pytorch/mac_transformer.py,sha256=R9Xu-BjqAe9ZY60IGk4aNXBx_L8THsjJ4QrkbTnRNHo,15346
+titans_pytorch/titans.py,sha256=5wuAoDULbgXTM8Nbq8bXrW3Fd2nsn22kpERRfJOwZiU,16367
+titans_pytorch-0.0.61.dist-info/METADATA,sha256=Cfhqnse_9nnFNqVGo9p_kxO_LVawwv4uuZOx4anqhf0,4457
+titans_pytorch-0.0.61.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+titans_pytorch-0.0.61.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+titans_pytorch-0.0.61.dist-info/RECORD,,
titans_pytorch-0.0.57.dist-info/RECORD DELETED
@@ -1,8 +0,0 @@
-titans_pytorch/__init__.py,sha256=Q0MQA3RS8vqzs-KzSGZkvLR7afQ6ZW9uMOq1MeNuFoY,170
-titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
-titans_pytorch/mac_transformer.py,sha256=-WS_SI--_5f1whMaJOH-mYCU37EjYU_iZTurGfs8zgI,15331
-titans_pytorch/titans.py,sha256=ZKm-LnVKh1Cxs2tSxr4CcY37KroOOmYtTFM2F3Zb8Xg,16122
-titans_pytorch-0.0.57.dist-info/METADATA,sha256=rwLIRndtBo22oJt0Xm9xK9zqOYV50Jfo6g7oVrKq7CQ,4457
-titans_pytorch-0.0.57.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-titans_pytorch-0.0.57.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-titans_pytorch-0.0.57.dist-info/RECORD,,