titans-pytorch 0.0.14__py3-none-any.whl → 0.0.15__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
titans_pytorch/__init__.py CHANGED
@@ -1,3 +1,5 @@
  from titans_pytorch.titans import (
-     NeuralMemory
+     NeuralMemory,
+     MemoryMLP,
+     MemoryAttention
  )
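With this change, the two new memory-model classes are re-exported at the package root alongside `NeuralMemory`:

```python
# the new classes can now be imported directly from the package root
from titans_pytorch import NeuralMemory, MemoryMLP, MemoryAttention
```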
titans_pytorch/titans.py CHANGED
@@ -6,7 +6,7 @@ import torch
  from torch import nn, Tensor
  import torch.nn.functional as F
  from torch.nn import Linear, Module
- from torch.func import functional_call, vmap, grad_and_value
+ from torch.func import functional_call, vmap, grad

  from tensordict import TensorDict

@@ -57,7 +57,7 @@ def pack_one_with_inverse(t, pattern):

  # classes

- class MLP(Module):
+ class MemoryMLP(Module):
      def __init__(
          self,
          dim,
@@ -80,6 +80,35 @@ class MLP(Module):

          return x

+ # improvised attention as memory module
+ # todo - expand if see signal in experiments
+
+ class MemoryAttention(Module):
+     def __init__(
+         self,
+         dim
+     ):
+         super().__init__()
+         self.weights = nn.ParameterList([
+             nn.Parameter(torch.randn(dim, dim)), # queries
+             nn.Parameter(torch.randn(dim, dim)), # keys
+             nn.Parameter(torch.randn(dim, dim)), # values
+         ])
+
+     def forward(self, x):
+         wq, wk, wv = self.weights
+
+         q = x @ wq
+         k = x @ wk
+         v = x @ wv
+
+         sim = q @ k.transpose(-1, -2)
+
+         attn = sim.softmax(dim = -1)
+
+         out = attn @ v
+         return out
+
  # main neural memory

  def default_loss_fn(pred, target):
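The new `MemoryAttention` module is plain single-head attention over three learned projection matrices, with no `1/√d` scaling, no causal mask, and no output projection, as visible in the hunk above. A minimal usage sketch; the batch and sequence sizes are illustrative:

```python
import torch
from titans_pytorch import MemoryAttention

attn_memory = MemoryAttention(dim = 64)

x = torch.randn(2, 128, 64)   # (batch, seq, dim)
out = attn_memory(x)          # attention output, same shape as the input

assert out.shape == x.shape
```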
@@ -122,7 +151,7 @@ class NeuralMemory(Module):
          # memory mlp

          if not exists(model):
-             model = MLP(dim_head, **default_mlp_kwargs)
+             model = MemoryMLP(dim_head, **default_mlp_kwargs)

          assert not exists(next(model.buffers(), None)), 'model cannot have buffers for now'

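`NeuralMemory` keeps `MemoryMLP` as its default memory model, but the `model` keyword shown above allows swapping in `MemoryAttention` (or any buffer-free module, per the assert). A hedged sketch; the `dim`, `chunk_size`, and `dim_head` values are illustrative and the keyword names beyond `model` are assumed from the project README:

```python
import torch
from titans_pytorch import NeuralMemory, MemoryAttention

# illustrative hyperparameters; the inner model's width must match the
# memory's per-head dimension, mirroring the default
# `MemoryMLP(dim_head, ...)` construction in the hunk above
mem = NeuralMemory(
    dim = 384,
    chunk_size = 64,
    dim_head = 64,
    model = MemoryAttention(dim = 64)
)
```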
@@ -141,7 +170,7 @@ class NeuralMemory(Module):
              loss = self.store_memory_loss_fn(pred, target) # simple mse loss in paper - eq (12) - |M(k) - v|²
              return loss

-         self.per_sample_grad_and_value_fn = vmap(grad_and_value(forward_and_loss), in_dims = (None, 0, 0))
+         self.per_sample_grad_fn = vmap(grad(forward_and_loss), in_dims = (None, 0, 0))

          # queries for retrieving from the model

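Switching from `grad_and_value` to `grad` drops the auxiliary loss value that was previously returned alongside the gradients; only per-sample gradients of the stored-memory loss are computed now. A self-contained sketch of the `vmap(grad(...))` pattern this relies on, using a toy stand-in model with illustrative shapes:

```python
import torch
from torch import nn
from torch.func import functional_call, vmap, grad

# toy stand-in for the memory model
model = nn.Linear(4, 4, bias = False)
params = dict(model.named_parameters())

def forward_and_loss(params, key, value):
    pred = functional_call(model, params, key)
    return ((pred - value) ** 2).mean()   # the |M(k) - v|² objective

# grad(...) differentiates w.r.t. the first argument (the param dict);
# vmap with in_dims = (None, 0, 0) shares the params across the batch
# while mapping over keys / values, yielding one gradient per sample
per_sample_grad_fn = vmap(grad(forward_and_loss), in_dims = (None, 0, 0))

keys = torch.randn(8, 4)
values = torch.randn(8, 4)

grads = per_sample_grad_fn(params, keys, values)
print(grads['weight'].shape)   # torch.Size([8, 4, 4]): one grad per sample
```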
@@ -235,7 +264,7 @@ class NeuralMemory(Module):

          # get grads and extra auxiliary loss (for backwarding through qkv projection in base neural memory module)

-         grads, aux_store_loss = self.per_sample_grad_and_value_fn(dict(curr_weights), keys, values)
+         grads = self.per_sample_grad_fn(dict(curr_weights), keys, values)

          grads = TensorDict(grads)

@@ -305,7 +334,7 @@ class NeuralMemory(Module):

          next_state = (curr_weights + last_update, next_momentum)

-         return updates, next_state, aux_store_loss.mean() / chunk_size
+         return updates, next_state

      def retrieve_memories(
          self,
@@ -382,6 +411,7 @@ class NeuralMemory(Module):
      def forward(
          self,
          seq,
+         store_seq = None,
          past_state: tuple[dict[str, Tensor], dict[str, Tensor]] | None = None,
          return_next_memories = False
      ):
@@ -396,7 +426,9 @@ class NeuralMemory(Module):
          if not exists(past_state):
              past_state = self.init_weights_and_momentum()

-         updates, next_memories, aux_kv_mse_loss = self.store_memories(seq, past_state)
+         store_seq = default(store_seq, seq)
+
+         updates, next_memories = self.store_memories(store_seq, past_state)

          past_weights, _ = past_state

@@ -405,4 +437,4 @@ class NeuralMemory(Module):
          if not return_next_memories:
              return retrieved

-         return retrieved, next_memories, aux_kv_mse_loss
+         return retrieved, next_memories
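The new `store_seq` argument decouples what gets written into memory from what is used to query it; when omitted, it falls back to `seq`, preserving the old behavior. A usage sketch; the `dim` and `chunk_size` values are illustrative, assuming the constructor arguments from the project README:

```python
import torch
from titans_pytorch import NeuralMemory

mem = NeuralMemory(dim = 384, chunk_size = 64)

seq = torch.randn(2, 1024, 384)
store_seq = torch.randn(2, 1024, 384)   # optional; defaults to `seq`

# memories are stored from `store_seq`, retrieval is queried with `seq`
retrieved = mem(seq, store_seq = store_seq)

# note the changed return signature: the auxiliary kv mse loss is gone
retrieved, next_memories = mem(seq, return_next_memories = True)
```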
titans_pytorch-0.0.14.dist-info/METADATA → titans_pytorch-0.0.15.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: titans-pytorch
- Version: 0.0.14
+ Version: 0.0.15
  Summary: Titans
  Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
  Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -109,3 +109,13 @@ $ python train.py
  url = {https://api.semanticscholar.org/CorpusID:275212078}
  }
  ```
+
+ ```bibtex
+ @software{Kyrylov_Accelerated_Scan_2024,
+     author = {Kyrylov, Volodymyr},
+     doi = {10.5281/zenodo.10600962},
+     title = {Accelerated Scan},
+     version = {0.1.2},
+     year = {2024}
+ }
+ ```
titans_pytorch-0.0.15.dist-info/RECORD ADDED
@@ -0,0 +1,7 @@
+ titans_pytorch/__init__.py,sha256=bu8p8kUA24EVrTz-ojixHTwV-6KTY9Y0cNJkaMW4Whw,91
+ titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
+ titans_pytorch/titans.py,sha256=ArFYKgI0p7N3mmv8b4ncxkl3gkKAXWrFrnl2quh2RqE,12930
+ titans_pytorch-0.0.15.dist-info/METADATA,sha256=3IC7BT7J3BYx23wUOYeuGgtz769dsIczBnWbm6oi0Tw,3811
+ titans_pytorch-0.0.15.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+ titans_pytorch-0.0.15.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+ titans_pytorch-0.0.15.dist-info/RECORD,,
titans_pytorch-0.0.14.dist-info/RECORD DELETED
@@ -1,7 +0,0 @@
- titans_pytorch/__init__.py,sha256=QKuJPCOJCdgtaPeKoHEkYkiQe65_LV9_8-cIMbBPU30,55
- titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
- titans_pytorch/titans.py,sha256=O4WO2I6GhxyZobqbgfFx01saFEKUxhA0BBwt70m2yeQ,12306
- titans_pytorch-0.0.14.dist-info/METADATA,sha256=HKDSJ3sWc54sN1_fOEYU7i5TjiQff49vcZG9G8EU6z4,3598
- titans_pytorch-0.0.14.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
- titans_pytorch-0.0.14.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
- titans_pytorch-0.0.14.dist-info/RECORD,,