titans-pytorch 0.0.46__py3-none-any.whl → 0.0.47__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- titans_pytorch/titans.py +14 -10
- {titans_pytorch-0.0.46.dist-info → titans_pytorch-0.0.47.dist-info}/METADATA +1 -1
- titans_pytorch-0.0.47.dist-info/RECORD +8 -0
- titans_pytorch-0.0.46.dist-info/RECORD +0 -8
- {titans_pytorch-0.0.46.dist-info → titans_pytorch-0.0.47.dist-info}/WHEEL +0 -0
- {titans_pytorch-0.0.46.dist-info → titans_pytorch-0.0.47.dist-info}/licenses/LICENSE +0 -0
titans_pytorch/titans.py
CHANGED
|
@@ -217,10 +217,10 @@ class NeuralMemory(Module):
|
|
|
217
217
|
def forward_and_loss(params, inputs, loss_weights, target):
|
|
218
218
|
pred = functional_call(self.memory_model, params, inputs)
|
|
219
219
|
loss = self.store_memory_loss_fn(pred, target) # simple mse loss in paper - eq (12) - |M(k) - v|²
|
|
220
|
-
|
|
221
|
-
return
|
|
220
|
+
weighted_loss = loss * loss_weights
|
|
221
|
+
return weighted_loss.sum(), loss.mean()
|
|
222
222
|
|
|
223
|
-
self.per_sample_grad_fn = vmap(grad(forward_and_loss), in_dims = (None, 0, 0, 0))
|
|
223
|
+
self.per_sample_grad_fn = vmap(grad(forward_and_loss, has_aux = True), in_dims = (None, 0, 0, 0))
|
|
224
224
|
|
|
225
225
|
# queries for retrieving from the model
|
|
226
226
|
|
|
@@ -282,7 +282,8 @@ class NeuralMemory(Module):
|
|
|
282
282
|
def store_memories(
|
|
283
283
|
self,
|
|
284
284
|
seq,
|
|
285
|
-
past_state: tuple[dict[str, Tensor], dict[str, Tensor]]
|
|
285
|
+
past_state: tuple[dict[str, Tensor], dict[str, Tensor]],
|
|
286
|
+
return_aux_kv_loss = False
|
|
286
287
|
):
|
|
287
288
|
|
|
288
289
|
seq = self.store_norm(seq)
|
|
@@ -330,7 +331,7 @@ class NeuralMemory(Module):
|
|
|
330
331
|
|
|
331
332
|
# get grads and extra auxiliary loss (for backwarding through qkv projection in base neural memory module)
|
|
332
333
|
|
|
333
|
-
grads = self.per_sample_grad_fn(dict(curr_weights), keys, adaptive_lr, values)
|
|
334
|
+
grads, aux_kv_recon_loss = self.per_sample_grad_fn(dict(curr_weights), keys, adaptive_lr, values)
|
|
334
335
|
|
|
335
336
|
grads = TensorDict(grads)
|
|
336
337
|
|
|
@@ -405,7 +406,10 @@ class NeuralMemory(Module):
|
|
|
405
406
|
|
|
406
407
|
next_state = (curr_weights + last_update, next_momentum)
|
|
407
408
|
|
|
408
|
-
|
|
409
|
+
if not return_aux_kv_loss:
|
|
410
|
+
return updates, next_state
|
|
411
|
+
|
|
412
|
+
return updates, next_state, aux_kv_recon_loss
|
|
409
413
|
|
|
410
414
|
def retrieve_memories(
|
|
411
415
|
self,
|
|
@@ -484,7 +488,7 @@ class NeuralMemory(Module):
|
|
|
484
488
|
seq,
|
|
485
489
|
store_seq = None,
|
|
486
490
|
past_state: tuple[dict[str, Tensor], dict[str, Tensor]] | None = None,
|
|
487
|
-
|
|
491
|
+
return_aux_kv_loss = False
|
|
488
492
|
):
|
|
489
493
|
batch, seq_len = seq.shape[:2]
|
|
490
494
|
|
|
@@ -499,13 +503,13 @@ class NeuralMemory(Module):
|
|
|
499
503
|
|
|
500
504
|
store_seq = default(store_seq, seq)
|
|
501
505
|
|
|
502
|
-
updates, next_memories = self.store_memories(store_seq, past_state)
|
|
506
|
+
updates, next_memories, aux_kv_recon_loss = self.store_memories(store_seq, past_state, return_aux_kv_loss = True)
|
|
503
507
|
|
|
504
508
|
past_weights, _ = past_state
|
|
505
509
|
|
|
506
510
|
retrieved = self.retrieve_memories(seq, past_weights + updates)
|
|
507
511
|
|
|
508
|
-
if not
|
|
512
|
+
if not return_aux_kv_loss:
|
|
509
513
|
return retrieved
|
|
510
514
|
|
|
511
|
-
return retrieved,
|
|
515
|
+
return retrieved, aux_kv_recon_loss
|
|
@@ -0,0 +1,8 @@
|
|
|
1
|
+
titans_pytorch/__init__.py,sha256=wnv_Cjdjqh_h5IqLkQ8xrTtA2K663ITEn-1JeeHofTo,150
|
|
2
|
+
titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
|
|
3
|
+
titans_pytorch/mac_transformer.py,sha256=mF8PYAjeAjLas1gkYybgzZX1AVK82A_ps_LY00ofYYs,9565
|
|
4
|
+
titans_pytorch/titans.py,sha256=ZcWxx6n-f8ttojRnK9fExavmT1bS-QSCRHQn7ldv7J0,15502
|
|
5
|
+
titans_pytorch-0.0.47.dist-info/METADATA,sha256=HjZxbJlnqsSgbioQz6KHWJb--8n18WDdL2T-jz-CFKc,4210
|
|
6
|
+
titans_pytorch-0.0.47.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
|
7
|
+
titans_pytorch-0.0.47.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
|
|
8
|
+
titans_pytorch-0.0.47.dist-info/RECORD,,
|
|
@@ -1,8 +0,0 @@
|
|
|
1
|
-
titans_pytorch/__init__.py,sha256=wnv_Cjdjqh_h5IqLkQ8xrTtA2K663ITEn-1JeeHofTo,150
|
|
2
|
-
titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
|
|
3
|
-
titans_pytorch/mac_transformer.py,sha256=mF8PYAjeAjLas1gkYybgzZX1AVK82A_ps_LY00ofYYs,9565
|
|
4
|
-
titans_pytorch/titans.py,sha256=qxQ8pZCz8GEDhKeJMEaeAEzH66GAGVBNaRdNam_-czg,15260
|
|
5
|
-
titans_pytorch-0.0.46.dist-info/METADATA,sha256=Gg1-_Mmp9u_sJYEvaRt5GzKhhJoTNHjBL3efjSSDLL0,4210
|
|
6
|
-
titans_pytorch-0.0.46.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
|
7
|
-
titans_pytorch-0.0.46.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
|
|
8
|
-
titans_pytorch-0.0.46.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|