titans-pytorch 0.0.57__py3-none-any.whl → 0.0.61__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- titans_pytorch/mac_transformer.py +14 -13
- titans_pytorch/titans.py +14 -6
- {titans_pytorch-0.0.57.dist-info → titans_pytorch-0.0.61.dist-info}/METADATA +1 -1
- titans_pytorch-0.0.61.dist-info/RECORD +8 -0
- titans_pytorch-0.0.57.dist-info/RECORD +0 -8
- {titans_pytorch-0.0.57.dist-info → titans_pytorch-0.0.61.dist-info}/WHEEL +0 -0
- {titans_pytorch-0.0.57.dist-info → titans_pytorch-0.0.61.dist-info}/licenses/LICENSE +0 -0
titans_pytorch/mac_transformer.py
CHANGED

@@ -24,8 +24,8 @@ def create_mac_block_mask(seq_len, window_size, persist_mem_len):
 
     def create_mac_mask(b, h, q_idx, kv_idx):
         is_persist_mem = kv_idx < persist_mem_len
-        causal_mask = q_idx >= (kv_idx -
-        block_diagonal = (q_idx // window_size) == ((kv_idx -
+        causal_mask = q_idx >= (kv_idx - persist_mem_len)
+        block_diagonal = (q_idx // window_size) == ((kv_idx - persist_mem_len) // window_size)
         return is_persist_mem | (~is_persist_mem & (causal_mask & block_diagonal))
 
     block_mask = create_block_mask(create_mac_mask, B = None, H = None, Q_LEN = seq_len, KV_LEN = seq_len + persist_mem_len, _compile = True)
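The new create_mac_mask predicate always exposes the persistent-memory columns and otherwise enforces a causal, block-diagonal pattern over window_size-wide segments before the result is handed to create_block_mask for flex attention. If it helps to see the pattern, here is a minimal sketch that evaluates the same predicate over broadcast index grids on toy sizes; the sizes and the plain-tensor evaluation are illustrative, not how the library applies the mask.

import torch

seq_len, window_size, persist_mem_len = 8, 4, 2   # toy sizes, chosen for illustration

q_idx = torch.arange(seq_len)[:, None]                      # query positions
kv_idx = torch.arange(seq_len + persist_mem_len)[None, :]   # key/value positions, persistent memory first

is_persist_mem = kv_idx < persist_mem_len
causal_mask = q_idx >= (kv_idx - persist_mem_len)
block_diagonal = (q_idx // window_size) == ((kv_idx - persist_mem_len) // window_size)

mask = is_persist_mem | (~is_persist_mem & (causal_mask & block_diagonal))
print(mask.int())   # persistent-memory columns are all ones; the rest is causal within each block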
@@ -292,12 +292,13 @@ class NeuralMemoryGatingWrapper(Module):
         self,
         dim,
         attn: SegmentedAttention,
-        neural_mem: NeuralMemory | None = None
+        neural_mem: NeuralMemory | None = None,
+        gate_attn_output = True
     ):
         super().__init__()
         self.attn = attn
         self.neural_mem = neural_mem
-        self.
+        self.gate_attn_output = gate_attn_output
 
     def forward(
         self,
@@ -313,21 +314,19 @@ class NeuralMemoryGatingWrapper(Module):
 
         # initial retrieve, still should store first, it doesn't make sense not to, unless if all layers share the same neural memory
 
-        retrieved,
+        retrieved, kv_aux_loss = mem(seq, return_aux_kv_loss = True)
 
-
+        if not self.gate_attn_output:
+            seq = seq + retrieved
 
         # attention
 
         attn_out, values = self.attn(seq, *args, **kwargs)
 
-
+        if self.gate_attn_output:
+            attn_out = attn_out * retrieved.sigmoid()
 
-
-
-        attn_out = attn_out * self.to_gates(retrieved).sigmoid()
-
-        return (attn_out, values), (first_kv_aux_loss + second_kv_aux_loss)
+        return (attn_out, values), kv_aux_loss
 
 # MAC transformer
 
@@ -340,6 +339,7 @@ class MemoryAsContextTransformer(Module):
         depth,
         segment_len,
         neural_memory_segment_len = None,
+        neural_mem_gate_attn_output = True,
         num_longterm_mem_tokens = 0,
         num_persist_mem_tokens = 0,
         dim_head = 64,
@@ -414,6 +414,7 @@ class MemoryAsContextTransformer(Module):
                 dim,
                 attn = attn,
                 neural_mem = mem,
+                gate_attn_output = neural_mem_gate_attn_output
             )
 
             ff = FeedForward(dim = dim, mult = ff_mult)
@@ -488,7 +489,7 @@ class MemoryAsContextTransformer(Module):
         flex_attn_fn = None
 
         if use_flex_attn:
-            block_mask = create_mac_block_mask(seq_len_with_mem,
+            block_mask = create_mac_block_mask(seq_len_with_mem, segment_len + num_longterm_mem_tokens, self.num_persist_mem_tokens)
             flex_attn_fn = partial(flex_attention, block_mask = block_mask)
 
         # value residual
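Taken together, the mac_transformer.py hunks add a gate_attn_output switch to NeuralMemoryGatingWrapper, exposed on MemoryAsContextTransformer as neural_mem_gate_attn_output. When the flag is on (the default), the retrieved neural-memory output gates the attention output through a sigmoid; when it is off, the retrieval is added to the sequence before attention instead. Below is a minimal standalone sketch of the two paths, using a stock nn.MultiheadAttention as a stand-in for SegmentedAttention and a random tensor in place of the real retrieval; it illustrates the control flow only, not the library's implementation.

import torch
import torch.nn as nn

dim, seq_len = 64, 128
seq = torch.randn(2, seq_len, dim)
retrieved = torch.randn(2, seq_len, dim)   # stand-in for the neural memory retrieval

attn = nn.MultiheadAttention(dim, num_heads = 4, batch_first = True)   # stand-in for SegmentedAttention
gate_attn_output = True   # mirrors neural_mem_gate_attn_output

if not gate_attn_output:
    # flag off: retrieved memories join the residual stream before attention
    seq = seq + retrieved

attn_out, _ = attn(seq, seq, seq)

if gate_attn_output:
    # flag on (default): attention output is gated elementwise by sigmoid of the retrieval
    attn_out = attn_out * retrieved.sigmoid()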
titans_pytorch/titans.py
CHANGED
@@ -289,6 +289,8 @@ class NeuralMemory(Module):
 
         self.use_accelerated_scan = use_accelerated_scan
 
+        self.register_buffer('zero', torch.tensor(0.), persistent = False)
+
     def init_weights_and_momentum(self):
         params = TensorDict(dict(self.memory_model.named_parameters()))
 
@@ -306,6 +308,13 @@ class NeuralMemory(Module):
         past_state: tuple[dict[str, Tensor], dict[str, Tensor]],
         return_aux_kv_loss = False
     ):
+        seq_len = seq.shape[-2]
+
+        # handle edge case
+
+        if seq_len < self.chunk_size:
+            past_weight, _ = past_state
+            return TensorDict(past_weight).clone().zero_(), self.zero
 
         seq = self.store_norm(seq)
 
@@ -425,12 +434,10 @@ class NeuralMemory(Module):
 
         last_update = updates.apply(lambda t: t[:, -1])
 
-        next_state = (curr_weights + last_update, next_momentum)
-
         if not return_aux_kv_loss:
-            return updates
+            return updates
 
-        return updates,
+        return updates, aux_kv_recon_loss.mean()
 
     def retrieve_memories(
         self,
@@ -442,7 +449,8 @@ class NeuralMemory(Module):
 
         seq = self.retrieve_norm(seq)
 
-
+        if seq_len < self.chunk_size:
+            return self.init_empty_memory_embed(batch, seq_len)
 
         seq = seq[:, (chunk_size - 1):]
         curtailed_seq_len = seq.shape[-2]
@@ -524,7 +532,7 @@ class NeuralMemory(Module):
 
         store_seq = default(store_seq, seq)
 
-        updates,
+        updates, aux_kv_recon_loss = self.store_memories(store_seq, past_state, return_aux_kv_loss = True)
 
         past_weights, _ = past_state
 
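The titans.py hunks register a non-persistent scalar zero buffer and use it to short-circuit sequences shorter than one chunk: store_memories returns zeroed-out weight updates plus a zero auxiliary loss, and retrieve_memories falls back to the empty memory embedding. Below is a self-contained sketch of that guard pattern on a toy module; the class, its single linear layer, and the stand-in reconstruction loss are illustrative assumptions, not the library's NeuralMemory.

import torch
from torch import nn

class TinyMemory(nn.Module):
    def __init__(self, dim, chunk_size = 16):
        super().__init__()
        self.chunk_size = chunk_size
        self.proj = nn.Linear(dim, dim)
        # non-persistent scalar reused as the "no loss" return value, as in the diff above
        self.register_buffer('zero', torch.tensor(0.), persistent = False)

    def store(self, seq):
        seq_len = seq.shape[-2]
        if seq_len < self.chunk_size:
            # too short to form a single chunk: zeroed updates, zero aux loss
            return torch.zeros_like(self.proj.weight), self.zero
        pred = self.proj(seq)
        aux_loss = (pred - seq).pow(2).mean()   # stand-in reconstruction loss
        return self.proj.weight, aux_loss

mem = TinyMemory(dim = 32, chunk_size = 16)
updates, loss = mem.store(torch.randn(1, 8, 32))   # 8 < 16, so this returns zeros and a zero scalar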
titans_pytorch-0.0.61.dist-info/RECORD
ADDED

@@ -0,0 +1,8 @@
+titans_pytorch/__init__.py,sha256=Q0MQA3RS8vqzs-KzSGZkvLR7afQ6ZW9uMOq1MeNuFoY,170
+titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
+titans_pytorch/mac_transformer.py,sha256=R9Xu-BjqAe9ZY60IGk4aNXBx_L8THsjJ4QrkbTnRNHo,15346
+titans_pytorch/titans.py,sha256=5wuAoDULbgXTM8Nbq8bXrW3Fd2nsn22kpERRfJOwZiU,16367
+titans_pytorch-0.0.61.dist-info/METADATA,sha256=Cfhqnse_9nnFNqVGo9p_kxO_LVawwv4uuZOx4anqhf0,4457
+titans_pytorch-0.0.61.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+titans_pytorch-0.0.61.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+titans_pytorch-0.0.61.dist-info/RECORD,,
titans_pytorch-0.0.57.dist-info/RECORD
REMOVED

@@ -1,8 +0,0 @@
-titans_pytorch/__init__.py,sha256=Q0MQA3RS8vqzs-KzSGZkvLR7afQ6ZW9uMOq1MeNuFoY,170
-titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
-titans_pytorch/mac_transformer.py,sha256=-WS_SI--_5f1whMaJOH-mYCU37EjYU_iZTurGfs8zgI,15331
-titans_pytorch/titans.py,sha256=ZKm-LnVKh1Cxs2tSxr4CcY37KroOOmYtTFM2F3Zb8Xg,16122
-titans_pytorch-0.0.57.dist-info/METADATA,sha256=rwLIRndtBo22oJt0Xm9xK9zqOYV50Jfo6g7oVrKq7CQ,4457
-titans_pytorch-0.0.57.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-titans_pytorch-0.0.57.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-titans_pytorch-0.0.57.dist-info/RECORD,,
{titans_pytorch-0.0.57.dist-info → titans_pytorch-0.0.61.dist-info}/WHEEL
File without changes

{titans_pytorch-0.0.57.dist-info → titans_pytorch-0.0.61.dist-info}/licenses/LICENSE
File without changes