titans-pytorch 0.0.14__tar.gz → 0.0.16__tar.gz

@@ -0,0 +1,24 @@
+ name: Tests the examples in README
+ on: [push, pull_request]
+
+ env:
+   TYPECHECK: True
+
+ jobs:
+   test:
+     runs-on: ubuntu-latest
+     steps:
+     - uses: actions/checkout@v4
+     - name: Install Python
+       uses: actions/setup-python@v5
+       with:
+         python-version: "3.11"
+     - name: Install dependencies
+       run: |
+         python -m pip install uv
+         python -m uv pip install --upgrade pip
+         python -m uv pip install torch --index-url https://download.pytorch.org/whl/nightly/cpu
+         python -m uv pip install -e .[test]
+     - name: Test with pytest
+       run: |
+         python -m pytest tests/
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: titans-pytorch
- Version: 0.0.14
+ Version: 0.0.16
  Summary: Titans
  Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
  Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -109,3 +109,13 @@ $ python train.py
  url = {https://api.semanticscholar.org/CorpusID:275212078}
  }
  ```
+
+ ```bibtex
+ @software{Kyrylov_Accelerated_Scan_2024,
+     author = {Kyrylov, Volodymyr},
+     doi = {10.5281/zenodo.10600962},
+     title = {Accelerated Scan},
+     version = {0.1.2},
+     year = {2024}
+ }
+ ```
@@ -58,3 +58,13 @@ $ python train.py
  url = {https://api.semanticscholar.org/CorpusID:275212078}
  }
  ```
+
+ ```bibtex
+ @software{Kyrylov_Accelerated_Scan_2024,
+     author = {Kyrylov, Volodymyr},
+     doi = {10.5281/zenodo.10600962},
+     title = {Accelerated Scan},
+     version = {0.1.2},
+     year = {2024}
+ }
+ ```
@@ -1,6 +1,6 @@
  [project]
  name = "titans-pytorch"
- version = "0.0.14"
+ version = "0.0.16"
  description = "Titans"
  authors = [
      { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -0,0 +1,28 @@
+ import torch
+ import pytest
+
+ def test_titans():
+     from titans_pytorch import NeuralMemory
+
+     mem = NeuralMemory(
+         dim = 384,
+         chunk_size = 64,
+     )
+
+     seq = torch.randn(2, 1024, 384)
+     retrieved = mem(seq)
+
+     assert seq.shape == retrieved.shape
+
+ def test_titans_attn_memory():
+     from titans_pytorch.titans_attn_memory import NeuralMemory
+
+     mem = NeuralMemory(
+         dim = 384,
+         chunk_size = 64,
+     )
+
+     seq = torch.randn(2, 1024, 384)
+     retrieved = mem(seq)
+
+     assert seq.shape == retrieved.shape
@@ -1,3 +1,4 @@
  from titans_pytorch.titans import (
-     NeuralMemory
+     NeuralMemory,
+     MemoryMLP,
  )
@@ -6,7 +6,7 @@ import torch
  from torch import nn, Tensor
  import torch.nn.functional as F
  from torch.nn import Linear, Module
- from torch.func import functional_call, vmap, grad_and_value
+ from torch.func import functional_call, vmap, grad
 
  from tensordict import TensorDict
 
@@ -57,7 +57,7 @@ def pack_one_with_inverse(t, pattern):
 
  # classes
 
- class MLP(Module):
+ class MemoryMLP(Module):
      def __init__(
          self,
          dim,
@@ -122,7 +122,7 @@ class NeuralMemory(Module):
          # memory mlp
 
          if not exists(model):
-             model = MLP(dim_head, **default_mlp_kwargs)
+             model = MemoryMLP(dim_head, **default_mlp_kwargs)
 
          assert not exists(next(model.buffers(), None)), 'model cannot have buffers for now'
 
@@ -141,7 +141,7 @@ class NeuralMemory(Module):
              loss = self.store_memory_loss_fn(pred, target) # simple mse loss in paper - eq (12) - |M(k) - v|²
              return loss
 
-         self.per_sample_grad_and_value_fn = vmap(grad_and_value(forward_and_loss), in_dims = (None, 0, 0))
+         self.per_sample_grad_fn = vmap(grad(forward_and_loss), in_dims = (None, 0, 0))
 
          # queries for retrieving from the model
 
@@ -235,7 +235,7 @@ class NeuralMemory(Module):
 
          # get grads and extra auxiliary loss (for backwarding through qkv projection in base neural memory module)
 
-         grads, aux_store_loss = self.per_sample_grad_and_value_fn(dict(curr_weights), keys, values)
+         grads = self.per_sample_grad_fn(dict(curr_weights), keys, values)
 
          grads = TensorDict(grads)
 
@@ -305,7 +305,7 @@ class NeuralMemory(Module):
 
          next_state = (curr_weights + last_update, next_momentum)
 
-         return updates, next_state, aux_store_loss.mean() / chunk_size
+         return updates, next_state
 
      def retrieve_memories(
          self,
@@ -382,6 +382,7 @@ class NeuralMemory(Module):
      def forward(
          self,
          seq,
+         store_seq = None,
          past_state: tuple[dict[str, Tensor], dict[str, Tensor]] | None = None,
          return_next_memories = False
      ):
@@ -396,7 +397,9 @@ class NeuralMemory(Module):
          if not exists(past_state):
              past_state = self.init_weights_and_momentum()
 
-         updates, next_memories, aux_kv_mse_loss = self.store_memories(seq, past_state)
+         store_seq = default(store_seq, seq)
+
+         updates, next_memories = self.store_memories(store_seq, past_state)
 
          past_weights, _ = past_state
 
@@ -405,4 +408,4 @@ class NeuralMemory(Module):
          if not return_next_memories:
              return retrieved
 
-         return retrieved, next_memories, aux_kv_mse_loss
+         return retrieved, next_memories
@@ -13,7 +13,11 @@ from local_attention import LocalTransformer
 
  from taylor_series_linear_attention import TaylorSeriesLinearAttn
 
- from titans_pytorch.titans import NeuralMemory
+ from titans_pytorch.titans import (
+     NeuralMemory,
+     MemoryAttention,
+     MemoryMLP
+ )
 
  # constants
 