titans-pytorch 0.0.53__tar.gz → 0.0.54__tar.gz

This diff reflects the changes between publicly released versions of this package as they appear in their public registry, and is provided for informational purposes only.
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.0.53
+Version: 0.0.54
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -1,6 +1,6 @@
 [project]
 name = "titans-pytorch"
-version = "0.0.53"
+version = "0.0.54"
 description = "Titans"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -6,17 +6,20 @@ from titans_pytorch import NeuralMemory

 @pytest.mark.parametrize('seq_len', (32, 1024, 77))
 @pytest.mark.parametrize('silu', (False, True))
+@pytest.mark.parametrize('learned_mem_model_weights', (False, True))
 @pytest.mark.parametrize('max_grad_norm', (None, 2.))
 def test_titans(
     seq_len,
     silu,
+    learned_mem_model_weights,
     max_grad_norm,
 ):
     mem = NeuralMemory(
         dim = 384,
         chunk_size = 64,
         activation = nn.SiLU() if silu else None,
-        max_grad_norm = max_grad_norm
+        max_grad_norm = max_grad_norm,
+        learned_mem_model_weights = learned_mem_model_weights
     )

     seq = torch.randn(2, seq_len, 384)
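
The hunk above extends the test matrix with the new `learned_mem_model_weights` flag. A minimal usage sketch of the flag outside the test harness (the dim / chunk_size values mirror the test and are illustrative; that the module returns a tensor shaped like its input follows the repository README and is assumed here):

import torch
from titans_pytorch import NeuralMemory

# learned_mem_model_weights = False freezes the inner memory model's
# parameters, so they are not trained alongside the rest of the network
mem = NeuralMemory(
    dim = 384,
    chunk_size = 64,
    learned_mem_model_weights = False
)

seq = torch.randn(2, 1024, 384)
retrieved = mem(seq)  # retrieval output, same shape as the input sequence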
@@ -1,6 +1,7 @@
 from titans_pytorch.titans import (
     NeuralMemory,
     MemoryMLP,
+    MemoryAttention
 )

 from titans_pytorch.mac_transformer import (
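
`MemoryAttention` is now re-exported at the package root. A sketch of swapping it in for the default `MemoryMLP` (this assumes, as the `MemoryMLP(dim_head, **default_model_kwargs)` fallback later in this diff suggests, that the `model` argument accepts any module operating at head width; `dim = 64` matches the usual `dim_head` default and is an assumption):

from titans_pytorch import NeuralMemory, MemoryAttention

# use the attention-based memory model instead of the default MLP
mem = NeuralMemory(
    dim = 384,
    chunk_size = 64,
    model = MemoryAttention(dim = 64)  # assumed head width of 64
)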
@@ -123,9 +123,12 @@ class MemoryMLP(Module):
 class MemoryAttention(Module):
     def __init__(
         self,
-        dim
+        dim,
+        scale = 8.
     ):
         super().__init__()
+        self.scale = scale
+
         self.weights = nn.ParameterList([
             nn.Parameter(torch.randn(dim, dim)), # queries
             nn.Parameter(torch.randn(dim, dim)), # keys
@@ -143,6 +146,7 @@ class MemoryAttention(Module):

         attn_out = F.scaled_dot_product_attention(
             q, k, v,
+            scale = self.scale,
             is_causal = True
         )

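The two hunks above give `MemoryAttention` an explicit softmax scale (default `8.`) and thread it into `F.scaled_dot_product_attention`. In PyTorch (2.1+, where the `scale` keyword exists), passing `scale` replaces the default `1 / sqrt(d_k)` factor applied to the attention logits. A self-contained illustration:

import torch
import torch.nn.functional as F

q = k = v = torch.randn(1, 8, 32, 64)  # (batch, heads, seq, head dim)

# default behaviour: logits scaled by 1 / sqrt(64)
out_default = F.scaled_dot_product_attention(q, k, v, is_causal = True)

# pinned scale, as MemoryAttention now does with self.scale = 8.
out_pinned = F.scaled_dot_product_attention(q, k, v, scale = 8., is_causal = True)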
@@ -174,6 +178,7 @@ class NeuralMemory(Module):
         default_step_transform_max_lr = 1e-2,
         pre_rmsnorm = True,
         post_rmsnorm = True,
+        learned_mem_model_weights = True,
         max_grad_norm: float | None = None,
         use_accelerated_scan = False,
         activation: Module | None = None,
@@ -212,6 +217,9 @@ class NeuralMemory(Module):
         if not exists(model):
             model = MemoryMLP(dim_head, **default_model_kwargs)

+        if not learned_mem_model_weights:
+            model.requires_grad_(False)
+
         assert not exists(next(model.buffers(), None)), 'model cannot have buffers for now'

         # the memory is the weights of the model
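
`requires_grad_(False)` recursively disables gradient tracking on every parameter of the module, which is how `learned_mem_model_weights = False` takes effect: the inner memory model keeps its random initialization instead of being trained end to end. A small demonstration of the mechanism in plain PyTorch (the Sequential here is a stand-in, not the actual memory model):

import torch
from torch import nn

# stand-in for the inner memory model
model = nn.Sequential(nn.Linear(64, 64), nn.SiLU(), nn.Linear(64, 64))

model.requires_grad_(False)

# every parameter is now excluded from autograd, so optimizers skip it
assert all(not p.requires_grad for p in model.parameters())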
@@ -9,7 +9,7 @@ from torch.optim import Adam
 from torch.nn import functional as F
 from torch.utils.data import DataLoader, Dataset

-from titans_pytorch.mac_transformer import MemoryAsContextTransformer
+from titans_pytorch import MemoryAsContextTransformer

 # constants

@@ -25,13 +25,14 @@ SHOULD_GENERATE = True
 SEQ_LEN = 512

 PROJECT_NAME = 'titans-mac-transformer'
-WANDB_ONLINE = False # turn this on to pipe experiment to cloud
+WANDB_ONLINE = True # turn this on to pipe experiment to cloud
 NEURAL_MEMORY_DEPTH = 2
 NUM_PERSIST_MEM = 4
 NUM_LONGTERM_MEM = 4
 NEURAL_MEM_LAYERS = (2, 4)
 WINDOW_SIZE = 32
 KV_RECON_LOSS_WEIGHT = 0.
+LEARNED_MEM_MODEL_WEIGHTS = True
 RUN_NAME = f'mac - {NUM_LONGTERM_MEM} longterm mems, layers {NEURAL_MEM_LAYERS}'

 # wandb experiment tracker
@@ -115,6 +116,7 @@ model = MemoryAsContextTransformer(
     neural_memory_kwargs = dict(
         dim_head = 64,
         heads = 4,
+        learned_mem_model_weights = LEARNED_MEM_MODEL_WEIGHTS,
         default_model_kwargs = dict(
             depth = NEURAL_MEMORY_DEPTH,
         )
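
The training script threads the new flag through `neural_memory_kwargs`. A sketch of the resulting model construction (constructor arguments other than `neural_memory_kwargs` follow the repository README and are assumptions here; `num_tokens`, `dim`, and `depth` are purely illustrative, while the remaining values mirror the script's constants shown in this diff):

from titans_pytorch import MemoryAsContextTransformer

model = MemoryAsContextTransformer(
    num_tokens = 256,                      # illustrative
    dim = 384,                             # illustrative
    depth = 8,                             # illustrative
    segment_len = 32,                      # WINDOW_SIZE
    num_persist_mem_tokens = 4,            # NUM_PERSIST_MEM
    num_longterm_mem_tokens = 4,           # NUM_LONGTERM_MEM
    neural_memory_layers = (2, 4),         # NEURAL_MEM_LAYERS
    neural_memory_kwargs = dict(
        dim_head = 64,
        heads = 4,
        learned_mem_model_weights = True,  # LEARNED_MEM_MODEL_WEIGHTS
        default_model_kwargs = dict(
            depth = 2                      # NEURAL_MEMORY_DEPTH
        )
    )
)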