titans-pytorch 0.0.14__py3-none-any.whl → 0.0.16__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
titans_pytorch/__init__.py CHANGED
@@ -1,3 +1,4 @@
 from titans_pytorch.titans import (
-    NeuralMemory
+    NeuralMemory,
+    MemoryMLP,
 )
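With this change, `MemoryMLP` joins `NeuralMemory` as a top-level export. A minimal sketch of the new import surface; the `NeuralMemory` constructor arguments follow the project README and may differ slightly at this version:

```python
import torch
from titans_pytorch import NeuralMemory, MemoryMLP  # MemoryMLP newly exported in 0.0.16

mem = NeuralMemory(
    dim = 384,
    chunk_size = 64
)

seq = torch.randn(2, 1024, 384)
retrieved = mem(seq)

assert seq.shape == retrieved.shape
```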
titans_pytorch/titans.py CHANGED
@@ -6,7 +6,7 @@ import torch
 from torch import nn, Tensor
 import torch.nn.functional as F
 from torch.nn import Linear, Module
-from torch.func import functional_call, vmap, grad_and_value
+from torch.func import functional_call, vmap, grad
 
 from tensordict import TensorDict
 
@@ -57,7 +57,7 @@ def pack_one_with_inverse(t, pattern):
 
 # classes
 
-class MLP(Module):
+class MemoryMLP(Module):
     def __init__(
         self,
         dim,
@@ -122,7 +122,7 @@ class NeuralMemory(Module):
         # memory mlp
 
         if not exists(model):
-            model = MLP(dim_head, **default_mlp_kwargs)
+            model = MemoryMLP(dim_head, **default_mlp_kwargs)
 
         assert not exists(next(model.buffers(), None)), 'model cannot have buffers for now'
 
@@ -141,7 +141,7 @@ class NeuralMemory(Module):
             loss = self.store_memory_loss_fn(pred, target) # simple mse loss in paper - eq (12) - |M(k) - v|²
             return loss
 
-        self.per_sample_grad_and_value_fn = vmap(grad_and_value(forward_and_loss), in_dims = (None, 0, 0))
+        self.per_sample_grad_fn = vmap(grad(forward_and_loss), in_dims = (None, 0, 0))
 
         # queries for retrieving from the model
 
@@ -235,7 +235,7 @@ class NeuralMemory(Module):
 
         # get grads and extra auxiliary loss (for backwarding through qkv projection in base neural memory module)
 
-        grads, aux_store_loss = self.per_sample_grad_and_value_fn(dict(curr_weights), keys, values)
+        grads = self.per_sample_grad_fn(dict(curr_weights), keys, values)
 
         grads = TensorDict(grads)
 
@@ -305,7 +305,7 @@ class NeuralMemory(Module):
 
         next_state = (curr_weights + last_update, next_momentum)
 
-        return updates, next_state, aux_store_loss.mean() / chunk_size
+        return updates, next_state
 
     def retrieve_memories(
         self,
@@ -382,6 +382,7 @@ class NeuralMemory(Module):
     def forward(
         self,
         seq,
+        store_seq = None,
         past_state: tuple[dict[str, Tensor], dict[str, Tensor]] | None = None,
         return_next_memories = False
     ):
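The new `store_seq` argument decouples the sequence that is committed to memory from the sequence that queries it; when omitted it defaults to `seq` (see the next hunk), preserving the old behavior. A usage sketch, with tensor shapes chosen only for illustration:

```python
import torch
from titans_pytorch import NeuralMemory

mem = NeuralMemory(dim = 384, chunk_size = 64)

seq = torch.randn(2, 1024, 384)        # sequence to retrieve against
store_seq = torch.randn(2, 1024, 384)  # separate sequence to commit to memory

retrieved = mem(seq, store_seq = store_seq)  # new in 0.0.16
retrieved = mem(seq)                         # store_seq defaults to seq, as before
```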
@@ -396,7 +397,9 @@ class NeuralMemory(Module):
         if not exists(past_state):
             past_state = self.init_weights_and_momentum()
 
-        updates, next_memories, aux_kv_mse_loss = self.store_memories(seq, past_state)
+        store_seq = default(store_seq, seq)
+
+        updates, next_memories = self.store_memories(store_seq, past_state)
 
         past_weights, _ = past_state
 
@@ -405,4 +408,4 @@ class NeuralMemory(Module):
         if not return_next_memories:
             return retrieved
 
-        return retrieved, next_memories, aux_kv_mse_loss
+        return retrieved, next_memories
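Callers that unpacked three values must now unpack two; the auxiliary kv MSE loss is gone from the public return. A sketch of the updated calling convention, reusing the `mem` instance from above; feeding `next_memories` back in as `past_state` matches the type annotation in the `forward` signature:

```python
# before (0.0.14): retrieved, next_memories, aux_kv_mse_loss = mem(seq, return_next_memories = True)
# after  (0.0.16):
retrieved, next_memories = mem(seq, return_next_memories = True)

# next_memories is a (weights, momentum) pair that can seed a subsequent call
retrieved = mem(seq, past_state = next_memories)
```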
{titans_pytorch-0.0.14.dist-info → titans_pytorch-0.0.16.dist-info}/METADATA RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.0.14
+Version: 0.0.16
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -109,3 +109,13 @@ $ python train.py
     url = {https://api.semanticscholar.org/CorpusID:275212078}
 }
 ```
+
+```bibtex
+@software{Kyrylov_Accelerated_Scan_2024,
+    author = {Kyrylov, Volodymyr},
+    doi = {10.5281/zenodo.10600962},
+    title = {Accelerated Scan},
+    version = {0.1.2},
+    year = {2024}
+}
+```
titans_pytorch-0.0.16.dist-info/RECORD ADDED
@@ -0,0 +1,7 @@
+titans_pytorch/__init__.py,sha256=nB0873FZ_OyCda3qFeWTdpO4LbbrXDEvtAefVLzh6g0,71
+titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
+titans_pytorch/titans.py,sha256=HYm0R_1w3s8MNPsyE2qAVpHGqTBX_AoWtjzxRfF1Ams,12269
+titans_pytorch-0.0.16.dist-info/METADATA,sha256=mzSgA4okWWSh97fncN2UKJaTVa3PWOHYVaFkQ1Ker0w,3811
+titans_pytorch-0.0.16.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+titans_pytorch-0.0.16.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+titans_pytorch-0.0.16.dist-info/RECORD,,
@@ -1,7 +0,0 @@
1
- titans_pytorch/__init__.py,sha256=QKuJPCOJCdgtaPeKoHEkYkiQe65_LV9_8-cIMbBPU30,55
2
- titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
3
- titans_pytorch/titans.py,sha256=O4WO2I6GhxyZobqbgfFx01saFEKUxhA0BBwt70m2yeQ,12306
4
- titans_pytorch-0.0.14.dist-info/METADATA,sha256=HKDSJ3sWc54sN1_fOEYU7i5TjiQff49vcZG9G8EU6z4,3598
5
- titans_pytorch-0.0.14.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
6
- titans_pytorch-0.0.14.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
7
- titans_pytorch-0.0.14.dist-info/RECORD,,