titans-pytorch 0.0.14-py3-none-any.whl → 0.0.15-py3-none-any.whl

titans_pytorch/__init__.py CHANGED
@@ -1,3 +1,5 @@
 from titans_pytorch.titans import (
-    NeuralMemory
+    NeuralMemory,
+    MemoryMLP,
+    MemoryAttention
 )
titans_pytorch/titans.py CHANGED
@@ -6,7 +6,7 @@ import torch
 from torch import nn, Tensor
 import torch.nn.functional as F
 from torch.nn import Linear, Module
-from torch.func import functional_call, vmap, grad_and_value
+from torch.func import functional_call, vmap, grad
 
 from tensordict import TensorDict
 
@@ -57,7 +57,7 @@ def pack_one_with_inverse(t, pattern):
 
 # classes
 
-class MLP(Module):
+class MemoryMLP(Module):
     def __init__(
         self,
         dim,
@@ -80,6 +80,35 @@ class MLP(Module):
 
         return x
 
+# improvised attention as memory module
+# todo - expand if see signal in experiments
+
+class MemoryAttention(Module):
+    def __init__(
+        self,
+        dim
+    ):
+        super().__init__()
+        self.weights = nn.ParameterList([
+            nn.Parameter(torch.randn(dim, dim)), # queries
+            nn.Parameter(torch.randn(dim, dim)), # keys
+            nn.Parameter(torch.randn(dim, dim)), # values
+        ])
+
+    def forward(self, x):
+        wq, wk, wv = self.weights
+
+        q = x @ wq
+        k = x @ wk
+        v = x @ wv
+
+        sim = q @ k.transpose(-1, -2)
+
+        attn = sim.softmax(dim = -1)
+
+        out = attn @ v
+        return out
+
 # main neural memory
 
 def default_loss_fn(pred, target):
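A minimal sketch of how the new MemoryAttention module could be swapped in as the inner memory model. Only `MemoryAttention(dim)` and the existence of a `model` argument are shown in this diff; the other NeuralMemory constructor kwargs and the shapes below are assumptions.

```python
# hedged usage sketch - NeuralMemory kwargs other than `model` are assumptions,
# and the dim given to MemoryAttention is assumed to match the per-head dim
# that NeuralMemory feeds its inner model
import torch
from titans_pytorch import NeuralMemory, MemoryAttention

mem = NeuralMemory(
    dim = 384,                          # assumed constructor kwarg
    chunk_size = 64,                    # assumed constructor kwarg
    model = MemoryAttention(dim = 384)  # replaces the default MemoryMLP
)

seq = torch.randn(2, 1024, 384)
retrieved = mem(seq)  # retrieved memories, one per token
```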
@@ -122,7 +151,7 @@ class NeuralMemory(Module):
         # memory mlp
 
         if not exists(model):
-            model = MLP(dim_head, **default_mlp_kwargs)
+            model = MemoryMLP(dim_head, **default_mlp_kwargs)
 
         assert not exists(next(model.buffers(), None)), 'model cannot have buffers for now'
 
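The buffer assert above is the constraint any custom `model` must satisfy, presumably because the memory weights are handled as a plain dict of parameters through `functional_call`. A brief illustration of a module that would be rejected; the `BadMemory` class is hypothetical, not library code.

```python
# hypothetical module used only to illustrate the 'model cannot have buffers' check
import torch
from torch.nn import Module, Linear

class BadMemory(Module):
    def __init__(self, dim):
        super().__init__()
        self.proj = Linear(dim, dim)
        self.register_buffer('scale', torch.ones(dim))  # state outside of parameters

    def forward(self, x):
        return self.proj(x) * self.scale

model = BadMemory(64)

# the same condition NeuralMemory asserts on - True here, so it would be rejected
has_buffers = next(model.buffers(), None) is not None
print(has_buffers)  # True
```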
@@ -141,7 +170,7 @@ class NeuralMemory(Module):
             loss = self.store_memory_loss_fn(pred, target) # simple mse loss in paper - eq (12) - |M(k) - v|²
             return loss
 
-        self.per_sample_grad_and_value_fn = vmap(grad_and_value(forward_and_loss), in_dims = (None, 0, 0))
+        self.per_sample_grad_fn = vmap(grad(forward_and_loss), in_dims = (None, 0, 0))
 
         # queries for retrieving from the model
 
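The switch from `grad_and_value` to `grad` means the per-sample call now returns only gradients, with no auxiliary loss value. A minimal, self-contained sketch of the `vmap(grad(...))` pattern used here, with a toy linear model and MSE loss standing in for the memory model; all names and shapes below are illustrative only.

```python
# minimal sketch of per-sample gradients via vmap(grad(...)), analogous to
# per_sample_grad_fn above; the toy model, loss, and shapes are illustrative only
import torch
from torch import nn
from torch.func import functional_call, vmap, grad

model = nn.Linear(4, 4)
params = dict(model.named_parameters())

def forward_and_loss(params, inputs, target):
    pred = functional_call(model, params, inputs)
    return ((pred - target) ** 2).mean()  # simple mse, as in the loss above

# grad(...) differentiates w.r.t. the first argument (the params dict);
# in_dims = (None, 0, 0) shares params while mapping over a leading sample dim
per_sample_grad_fn = vmap(grad(forward_and_loss), in_dims = (None, 0, 0))

inputs  = torch.randn(8, 4)
targets = torch.randn(8, 4)

grads = per_sample_grad_fn(params, inputs, targets)
print(grads['weight'].shape)  # torch.Size([8, 4, 4]) - one gradient per sample
```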
@@ -235,7 +264,7 @@ class NeuralMemory(Module):
 
         # get grads and extra auxiliary loss (for backwarding through qkv projection in base neural memory module)
 
-        grads, aux_store_loss = self.per_sample_grad_and_value_fn(dict(curr_weights), keys, values)
+        grads = self.per_sample_grad_fn(dict(curr_weights), keys, values)
 
         grads = TensorDict(grads)
 
@@ -305,7 +334,7 @@ class NeuralMemory(Module):
 
         next_state = (curr_weights + last_update, next_momentum)
 
-        return updates, next_state, aux_store_loss.mean() / chunk_size
+        return updates, next_state
 
     def retrieve_memories(
         self,
@@ -382,6 +411,7 @@ class NeuralMemory(Module):
     def forward(
         self,
         seq,
+        store_seq = None,
         past_state: tuple[dict[str, Tensor], dict[str, Tensor]] | None = None,
        return_next_memories = False
     ):
@@ -396,7 +426,9 @@ class NeuralMemory(Module):
         if not exists(past_state):
             past_state = self.init_weights_and_momentum()
 
-        updates, next_memories, aux_kv_mse_loss = self.store_memories(seq, past_state)
+        store_seq = default(store_seq, seq)
+
+        updates, next_memories = self.store_memories(store_seq, past_state)
 
         past_weights, _ = past_state
 
@@ -405,4 +437,4 @@ class NeuralMemory(Module):
         if not return_next_memories:
             return retrieved
 
-        return retrieved, next_memories, aux_kv_mse_loss
+        return retrieved, next_memories
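With `store_seq` added and the auxiliary kv-reconstruction loss dropped from the return values, a call against 0.0.15 could look roughly like this. Only the forward signature comes from the diff; the constructor kwargs and tensor shapes are assumptions.

```python
# hedged usage sketch of the updated forward signature; NeuralMemory constructor
# kwargs below are assumptions, the forward arguments come from the diff above
import torch
from titans_pytorch import NeuralMemory

mem = NeuralMemory(dim = 384, chunk_size = 64)  # assumed kwargs

seq = torch.randn(2, 1024, 384)        # sequence used for retrieval queries
store_seq = torch.randn(2, 1024, 384)  # optional separate sequence to store; defaults to seq

retrieved = mem(seq, store_seq = store_seq)

# requesting the next memory state now returns two values instead of three,
# since the auxiliary kv mse loss is no longer propagated
retrieved, next_memories = mem(seq, store_seq = store_seq, return_next_memories = True)
```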
titans_pytorch-0.0.14.dist-info/METADATA → titans_pytorch-0.0.15.dist-info/METADATA RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.0.14
+Version: 0.0.15
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -109,3 +109,13 @@ $ python train.py
     url = {https://api.semanticscholar.org/CorpusID:275212078}
 }
 ```
+
+```bibtex
+@software{Kyrylov_Accelerated_Scan_2024,
+    author = {Kyrylov, Volodymyr},
+    doi = {10.5281/zenodo.10600962},
+    title = {Accelerated Scan},
+    version = {0.1.2},
+    year = {2024}
+}
+```
titans_pytorch-0.0.15.dist-info/RECORD ADDED
@@ -0,0 +1,7 @@
+titans_pytorch/__init__.py,sha256=bu8p8kUA24EVrTz-ojixHTwV-6KTY9Y0cNJkaMW4Whw,91
+titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
+titans_pytorch/titans.py,sha256=ArFYKgI0p7N3mmv8b4ncxkl3gkKAXWrFrnl2quh2RqE,12930
+titans_pytorch-0.0.15.dist-info/METADATA,sha256=3IC7BT7J3BYx23wUOYeuGgtz769dsIczBnWbm6oi0Tw,3811
+titans_pytorch-0.0.15.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+titans_pytorch-0.0.15.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+titans_pytorch-0.0.15.dist-info/RECORD,,
@@ -1,7 +0,0 @@
1
- titans_pytorch/__init__.py,sha256=QKuJPCOJCdgtaPeKoHEkYkiQe65_LV9_8-cIMbBPU30,55
2
- titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
3
- titans_pytorch/titans.py,sha256=O4WO2I6GhxyZobqbgfFx01saFEKUxhA0BBwt70m2yeQ,12306
4
- titans_pytorch-0.0.14.dist-info/METADATA,sha256=HKDSJ3sWc54sN1_fOEYU7i5TjiQff49vcZG9G8EU6z4,3598
5
- titans_pytorch-0.0.14.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
6
- titans_pytorch-0.0.14.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
7
- titans_pytorch-0.0.14.dist-info/RECORD,,