titans-pytorch 0.0.45__py3-none-any.whl → 0.0.47__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- titans_pytorch/mac_transformer.py +2 -2
- titans_pytorch/titans.py +14 -14
- {titans_pytorch-0.0.45.dist-info → titans_pytorch-0.0.47.dist-info}/METADATA +1 -1
- titans_pytorch-0.0.47.dist-info/RECORD +8 -0
- titans_pytorch-0.0.45.dist-info/RECORD +0 -8
- {titans_pytorch-0.0.45.dist-info → titans_pytorch-0.0.47.dist-info}/WHEEL +0 -0
- {titans_pytorch-0.0.45.dist-info → titans_pytorch-0.0.47.dist-info}/licenses/LICENSE +0 -0
|
@@ -311,7 +311,7 @@ class MemoryAsContextTransformer(Module):
|
|
|
311
311
|
|
|
312
312
|
pos_emb = self.axial_pos_emb((neural_mem_windows, neural_mem_segment_len), flatten = True)
|
|
313
313
|
|
|
314
|
-
|
|
314
|
+
x = x + pos_emb[:seq_len_with_mem]
|
|
315
315
|
|
|
316
316
|
# value residual
|
|
317
317
|
|
|
@@ -324,7 +324,7 @@ class MemoryAsContextTransformer(Module):
|
|
|
324
324
|
for (attn, ff), maybe_neural_mem in zip(self.layers, self.neural_mem_layers):
|
|
325
325
|
|
|
326
326
|
if exists(maybe_neural_mem):
|
|
327
|
-
x = maybe_neural_mem(x
|
|
327
|
+
x = maybe_neural_mem(x)
|
|
328
328
|
|
|
329
329
|
x, values = attn(x, value_residual = value_residual)
|
|
330
330
|
|
titans_pytorch/titans.py
CHANGED
|
@@ -217,10 +217,10 @@ class NeuralMemory(Module):
|
|
|
217
217
|
def forward_and_loss(params, inputs, loss_weights, target):
|
|
218
218
|
pred = functional_call(self.memory_model, params, inputs)
|
|
219
219
|
loss = self.store_memory_loss_fn(pred, target) # simple mse loss in paper - eq (12) - |M(k) - v|²
|
|
220
|
-
|
|
221
|
-
return
|
|
220
|
+
weighted_loss = loss * loss_weights
|
|
221
|
+
return weighted_loss.sum(), loss.mean()
|
|
222
222
|
|
|
223
|
-
self.per_sample_grad_fn = vmap(grad(forward_and_loss), in_dims = (None, 0, 0, 0))
|
|
223
|
+
self.per_sample_grad_fn = vmap(grad(forward_and_loss, has_aux = True), in_dims = (None, 0, 0, 0))
|
|
224
224
|
|
|
225
225
|
# queries for retrieving from the model
|
|
226
226
|
|
|
@@ -282,7 +282,8 @@ class NeuralMemory(Module):
|
|
|
282
282
|
def store_memories(
|
|
283
283
|
self,
|
|
284
284
|
seq,
|
|
285
|
-
past_state: tuple[dict[str, Tensor], dict[str, Tensor]]
|
|
285
|
+
past_state: tuple[dict[str, Tensor], dict[str, Tensor]],
|
|
286
|
+
return_aux_kv_loss = False
|
|
286
287
|
):
|
|
287
288
|
|
|
288
289
|
seq = self.store_norm(seq)
|
|
@@ -330,7 +331,7 @@ class NeuralMemory(Module):
|
|
|
330
331
|
|
|
331
332
|
# get grads and extra auxiliary loss (for backwarding through qkv projection in base neural memory module)
|
|
332
333
|
|
|
333
|
-
grads = self.per_sample_grad_fn(dict(curr_weights), keys, adaptive_lr, values)
|
|
334
|
+
grads, aux_kv_recon_loss = self.per_sample_grad_fn(dict(curr_weights), keys, adaptive_lr, values)
|
|
334
335
|
|
|
335
336
|
grads = TensorDict(grads)
|
|
336
337
|
|
|
@@ -405,7 +406,10 @@ class NeuralMemory(Module):
|
|
|
405
406
|
|
|
406
407
|
next_state = (curr_weights + last_update, next_momentum)
|
|
407
408
|
|
|
408
|
-
|
|
409
|
+
if not return_aux_kv_loss:
|
|
410
|
+
return updates, next_state
|
|
411
|
+
|
|
412
|
+
return updates, next_state, aux_kv_recon_loss
|
|
409
413
|
|
|
410
414
|
def retrieve_memories(
|
|
411
415
|
self,
|
|
@@ -484,14 +488,10 @@ class NeuralMemory(Module):
|
|
|
484
488
|
seq,
|
|
485
489
|
store_seq = None,
|
|
486
490
|
past_state: tuple[dict[str, Tensor], dict[str, Tensor]] | None = None,
|
|
487
|
-
|
|
488
|
-
pos_emb: Tensor | None = None
|
|
491
|
+
return_aux_kv_loss = False
|
|
489
492
|
):
|
|
490
493
|
batch, seq_len = seq.shape[:2]
|
|
491
494
|
|
|
492
|
-
if exists(pos_emb):
|
|
493
|
-
seq = seq + pos_emb
|
|
494
|
-
|
|
495
495
|
if seq_len < self.chunk_size:
|
|
496
496
|
return self.init_empty_memory_embed(batch, seq_len)
|
|
497
497
|
|
|
@@ -503,13 +503,13 @@ class NeuralMemory(Module):
|
|
|
503
503
|
|
|
504
504
|
store_seq = default(store_seq, seq)
|
|
505
505
|
|
|
506
|
-
updates, next_memories = self.store_memories(store_seq, past_state)
|
|
506
|
+
updates, next_memories, aux_kv_recon_loss = self.store_memories(store_seq, past_state, return_aux_kv_loss = True)
|
|
507
507
|
|
|
508
508
|
past_weights, _ = past_state
|
|
509
509
|
|
|
510
510
|
retrieved = self.retrieve_memories(seq, past_weights + updates)
|
|
511
511
|
|
|
512
|
-
if not
|
|
512
|
+
if not return_aux_kv_loss:
|
|
513
513
|
return retrieved
|
|
514
514
|
|
|
515
|
-
return retrieved,
|
|
515
|
+
return retrieved, aux_kv_recon_loss
|
|
@@ -0,0 +1,8 @@
|
|
|
1
|
+
titans_pytorch/__init__.py,sha256=wnv_Cjdjqh_h5IqLkQ8xrTtA2K663ITEn-1JeeHofTo,150
|
|
2
|
+
titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
|
|
3
|
+
titans_pytorch/mac_transformer.py,sha256=mF8PYAjeAjLas1gkYybgzZX1AVK82A_ps_LY00ofYYs,9565
|
|
4
|
+
titans_pytorch/titans.py,sha256=ZcWxx6n-f8ttojRnK9fExavmT1bS-QSCRHQn7ldv7J0,15502
|
|
5
|
+
titans_pytorch-0.0.47.dist-info/METADATA,sha256=HjZxbJlnqsSgbioQz6KHWJb--8n18WDdL2T-jz-CFKc,4210
|
|
6
|
+
titans_pytorch-0.0.47.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
|
7
|
+
titans_pytorch-0.0.47.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
|
|
8
|
+
titans_pytorch-0.0.47.dist-info/RECORD,,
|
|
@@ -1,8 +0,0 @@
|
|
|
1
|
-
titans_pytorch/__init__.py,sha256=wnv_Cjdjqh_h5IqLkQ8xrTtA2K663ITEn-1JeeHofTo,150
|
|
2
|
-
titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
|
|
3
|
-
titans_pytorch/mac_transformer.py,sha256=SFB7sXDt1bYpwt_PVrXM0-1vXKEemBTAfnfboU66A7M,9586
|
|
4
|
-
titans_pytorch/titans.py,sha256=7LZIbaavC0bk85UBPzNzZP6YxKeFb0ujZ9k5IU048aI,15360
|
|
5
|
-
titans_pytorch-0.0.45.dist-info/METADATA,sha256=EqrDXchEvzFbz1BqSdAB8HkPMjRY3KYyBSu16hbKTUs,4210
|
|
6
|
-
titans_pytorch-0.0.45.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
|
7
|
-
titans_pytorch-0.0.45.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
|
|
8
|
-
titans_pytorch-0.0.45.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|