titans-pytorch 0.0.58__py3-none-any.whl → 0.0.61__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- titans_pytorch/mac_transformer.py +3 -3
- titans_pytorch/titans.py +14 -6
- {titans_pytorch-0.0.58.dist-info → titans_pytorch-0.0.61.dist-info}/METADATA +1 -1
- titans_pytorch-0.0.61.dist-info/RECORD +8 -0
- titans_pytorch-0.0.58.dist-info/RECORD +0 -8
- {titans_pytorch-0.0.58.dist-info → titans_pytorch-0.0.61.dist-info}/WHEEL +0 -0
- {titans_pytorch-0.0.58.dist-info → titans_pytorch-0.0.61.dist-info}/licenses/LICENSE +0 -0
titans_pytorch/mac_transformer.py
CHANGED

@@ -24,8 +24,8 @@ def create_mac_block_mask(seq_len, window_size, persist_mem_len):
 
     def create_mac_mask(b, h, q_idx, kv_idx):
         is_persist_mem = kv_idx < persist_mem_len
-        causal_mask = q_idx >= (kv_idx -
-        block_diagonal = (q_idx // window_size) == ((kv_idx -
+        causal_mask = q_idx >= (kv_idx - persist_mem_len)
+        block_diagonal = (q_idx // window_size) == ((kv_idx - persist_mem_len) // window_size)
         return is_persist_mem | (~is_persist_mem & (causal_mask & block_diagonal))
 
     block_mask = create_block_mask(create_mac_mask, B = None, H = None, Q_LEN = seq_len, KV_LEN = seq_len + persist_mem_len, _compile = True)
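
The two added lines complete the key-index shift: persist_mem_len persistent-memory tokens are prepended to the keys/values, so kv_idx must be offset back by that amount before the causal and block-diagonal tests. A standalone sketch with assumed toy sizes (not code from the package) of what the fixed predicate produces:

import torch

seq_len, window_size, persist_mem_len = 4, 2, 1  # illustrative sizes only

q_idx = torch.arange(seq_len).view(-1, 1)                     # queries: real tokens
kv_idx = torch.arange(seq_len + persist_mem_len).view(1, -1)  # keys: persistent mem prepended

is_persist_mem = kv_idx < persist_mem_len
causal_mask = q_idx >= (kv_idx - persist_mem_len)
block_diagonal = (q_idx // window_size) == ((kv_idx - persist_mem_len) // window_size)

mask = is_persist_mem | (~is_persist_mem & (causal_mask & block_diagonal))
print(mask.int())  # column 0 (persistent memory) is all ones; the rest is block-causal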

@@ -489,7 +489,7 @@ class MemoryAsContextTransformer(Module):
         flex_attn_fn = None
 
         if use_flex_attn:
-            block_mask = create_mac_block_mask(seq_len_with_mem,
+            block_mask = create_mac_block_mask(seq_len_with_mem, segment_len + num_longterm_mem_tokens, self.num_persist_mem_tokens)
             flex_attn_fn = partial(flex_attention, block_mask = block_mask)
 
         # value residual
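
The call now passes the attention window (segment_len + num_longterm_mem_tokens) and the persistent-memory count explicitly, and the resulting block mask is bound into flex_attention via partial. A minimal sketch, assuming PyTorch >= 2.5 and a plain causal predicate standing in for create_mac_mask:

from functools import partial

import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

def causal(b, h, q_idx, kv_idx):  # stand-in predicate, not create_mac_mask
    return q_idx >= kv_idx

block_mask = create_block_mask(causal, B = None, H = None, Q_LEN = 128, KV_LEN = 128)
attn_fn = partial(flex_attention, block_mask = block_mask)  # mirrors the changed line

q = k = v = torch.randn(1, 2, 128, 64)  # (batch, heads, seq, head dim) - assumed shapes
out = attn_fn(q, k, v)                  # same shape as q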
titans_pytorch/titans.py
CHANGED

@@ -289,6 +289,8 @@ class NeuralMemory(Module):
 
         self.use_accelerated_scan = use_accelerated_scan
 
+        self.register_buffer('zero', torch.tensor(0.), persistent = False)
+
     def init_weights_and_momentum(self):
        params = TensorDict(dict(self.memory_model.named_parameters()))
 
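
The new zero buffer gives the module a scalar that always lives on the right device and dtype, used below as the placeholder auxiliary loss. A small sketch of the assumed rationale for persistent = False:

import torch
from torch import nn

class Demo(nn.Module):
    def __init__(self):
        super().__init__()
        # persistent = False: the buffer follows .to(device)/.half() with the
        # module but is excluded from the state_dict, so checkpoints are unchanged
        self.register_buffer('zero', torch.tensor(0.), persistent = False)

demo = Demo()
print('zero' in demo.state_dict())  # False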

@@ -306,6 +308,13 @@ class NeuralMemory(Module):
         past_state: tuple[dict[str, Tensor], dict[str, Tensor]],
         return_aux_kv_loss = False
     ):
+        seq_len = seq.shape[-2]
+
+        # handle edge case
+
+        if seq_len < self.chunk_size:
+            past_weight, _ = past_state
+            return TensorDict(past_weight).clone().zero_(), self.zero
 
         seq = self.store_norm(seq)
 
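
store_memories now bails out when the sequence cannot fill a single chunk, returning a zeroed copy of the past weights (no update) together with the zero buffer as the auxiliary loss. An assumed toy illustration of that return value, mirroring the package's TensorDict usage:

import torch
from tensordict import TensorDict

chunk_size = 4
past_weight = TensorDict({'w1': torch.randn(8, 8), 'w2': torch.randn(8, 8)})
seq = torch.randn(1, 2, 16)  # sequence of length 2 < chunk_size

if seq.shape[-2] < chunk_size:
    updates = past_weight.clone().zero_()  # a zero update for every memory parameter
    aux_loss = torch.tensor(0.)            # stand-in for the module's zero buffer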

@@ -425,12 +434,10 @@ class NeuralMemory(Module):
 
         last_update = updates.apply(lambda t: t[:, -1])
 
-        next_state = (curr_weights + last_update, next_momentum)
-
         if not return_aux_kv_loss:
-            return updates
+            return updates
 
-        return updates,
+        return updates, aux_kv_recon_loss.mean()
 
     def retrieve_memories(
         self,

@@ -442,7 +449,8 @@ class NeuralMemory(Module):
 
         seq = self.retrieve_norm(seq)
 
-
+        if seq_len < self.chunk_size:
+            return self.init_empty_memory_embed(batch, seq_len)
 
         seq = seq[:, (chunk_size - 1):]
         curtailed_seq_len = seq.shape[-2]
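
Retrieval gets the mirror-image guard: a sub-chunk sequence skips retrieval and falls back to init_empty_memory_embed. A sketch assuming that helper broadcasts a learned embedding over batch and sequence (the names and shapes here are illustrative, not taken from the diff):

import torch
from torch import nn
from einops import repeat

batch, seq_len, dim, chunk_size = 2, 3, 16, 4
empty_memory_embed = nn.Parameter(torch.zeros(dim))  # hypothetical learned embed

if seq_len < chunk_size:
    retrieved = repeat(empty_memory_embed, 'd -> b n d', b = batch, n = seq_len)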

@@ -524,7 +532,7 @@ class NeuralMemory(Module):
 
         store_seq = default(store_seq, seq)
 
-        updates,
+        updates, aux_kv_recon_loss = self.store_memories(store_seq, past_state, return_aux_kv_loss = True)
 
         past_weights, _ = past_state
 

titans_pytorch-0.0.61.dist-info/RECORD
ADDED

@@ -0,0 +1,8 @@
+titans_pytorch/__init__.py,sha256=Q0MQA3RS8vqzs-KzSGZkvLR7afQ6ZW9uMOq1MeNuFoY,170
+titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
+titans_pytorch/mac_transformer.py,sha256=R9Xu-BjqAe9ZY60IGk4aNXBx_L8THsjJ4QrkbTnRNHo,15346
+titans_pytorch/titans.py,sha256=5wuAoDULbgXTM8Nbq8bXrW3Fd2nsn22kpERRfJOwZiU,16367
+titans_pytorch-0.0.61.dist-info/METADATA,sha256=Cfhqnse_9nnFNqVGo9p_kxO_LVawwv4uuZOx4anqhf0,4457
+titans_pytorch-0.0.61.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+titans_pytorch-0.0.61.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+titans_pytorch-0.0.61.dist-info/RECORD,,

titans_pytorch-0.0.58.dist-info/RECORD
DELETED

@@ -1,8 +0,0 @@
-titans_pytorch/__init__.py,sha256=Q0MQA3RS8vqzs-KzSGZkvLR7afQ6ZW9uMOq1MeNuFoY,170
-titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
-titans_pytorch/mac_transformer.py,sha256=kk8s8Q2WmbJxCVi8PcqSUyJBc8-CDAHrVjt6M0d_kFs,15323
-titans_pytorch/titans.py,sha256=ZKm-LnVKh1Cxs2tSxr4CcY37KroOOmYtTFM2F3Zb8Xg,16122
-titans_pytorch-0.0.58.dist-info/METADATA,sha256=a-Y6MV_89D44HlB7eKpurh-sw5DDiS-pIVei3Uw_uGE,4457
-titans_pytorch-0.0.58.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-titans_pytorch-0.0.58.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-titans_pytorch-0.0.58.dist-info/RECORD,,

{titans_pytorch-0.0.58.dist-info → titans_pytorch-0.0.61.dist-info}/WHEEL
File without changes

{titans_pytorch-0.0.58.dist-info → titans_pytorch-0.0.61.dist-info}/licenses/LICENSE
File without changes