titans-pytorch 0.1.2__tar.gz → 0.1.6__tar.gz

This diff shows the changes between publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the packages as they appear in their respective public registries.
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.1.2
+Version: 0.1.6
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -46,6 +46,7 @@ Requires-Dist: torch>=2.2
 Requires-Dist: tqdm
 Requires-Dist: x-transformers
 Provides-Extra: examples
+Requires-Dist: adam-atan2-pytorch>=0.1.18; extra == 'examples'
 Requires-Dist: wandb; extra == 'examples'
 Provides-Extra: test
 Requires-Dist: pytest; extra == 'test'
@@ -1,6 +1,6 @@
 [project]
 name = "titans-pytorch"
-version = "0.1.2"
+version = "0.1.6"
 description = "Titans"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -45,6 +45,7 @@ Repository = "https://github.com/lucidrains/titans-pytorch"
 [project.optional-dependencies]

 examples = [
+    "adam-atan2-pytorch>=0.1.18",
     "wandb"
 ]

@@ -11,12 +11,14 @@ def exists(v):
 @pytest.mark.parametrize('seq_len', (32, 1024, 77))
 @pytest.mark.parametrize('silu', (False, True))
 @pytest.mark.parametrize('learned_mem_model_weights', (False, True))
+@pytest.mark.parametrize('attn_pool_chunks', (False, True))
 @pytest.mark.parametrize('max_grad_norm', (None, 2.))
 @pytest.mark.parametrize('per_parameter_lr_modulation', (False, True))
 def test_titans(
     seq_len,
     silu,
     learned_mem_model_weights,
+    attn_pool_chunks,
     max_grad_norm,
     per_parameter_lr_modulation
 ):
@@ -24,6 +26,7 @@ def test_titans(
         dim = 384,
         chunk_size = 64,
         activation = nn.SiLU() if silu else None,
+        attn_pool_chunks = attn_pool_chunks,
         max_grad_norm = max_grad_norm,
         per_parameter_lr_modulation = per_parameter_lr_modulation,
         learned_mem_model_weights = learned_mem_model_weights
@@ -2,7 +2,8 @@ from titans_pytorch.titans import (
     NeuralMemory,
     MemoryMLP,
     MemoryAttention,
-    FactorizedMemoryMLP
+    FactorizedMemoryMLP,
+    GatedResidualMemoryMLP
 )

 from titans_pytorch.mac_transformer import (
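The hunk above exports the new GatedResidualMemoryMLP (the class itself is added in a later hunk below). As a hedged usage sketch, and assuming NeuralMemory still accepts a custom memory network through its model keyword as in earlier releases, it could be wired in like so:

# sketch only: using the newly exported gated-residual memory MLP as the neural memory model
# the model keyword and the per-head width (dim = 64 here) are assumptions, not confirmed by this diff
import torch
from titans_pytorch import NeuralMemory, GatedResidualMemoryMLP

mem = NeuralMemory(
    dim = 384,
    chunk_size = 64,
    model = GatedResidualMemoryMLP(dim = 64, depth = 2)
)

seq = torch.randn(2, 1024, 384)
retrieved = mem(seq)  # forward signature varies between versions; treat as illustrative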
@@ -1,10 +1,11 @@
 from __future__ import annotations
 from typing import Callable
+
 import math
 from functools import partial

 import torch
-from torch import nn, Tensor
+from torch import nn, cat, Tensor
 import torch.nn.functional as F
 from torch.nn import Linear, Module, Parameter, ParameterList
 from torch.func import functional_call, vmap, grad
@@ -18,7 +19,7 @@ from titans_pytorch.associative_scan import (
 )

 import einx
-from einops import rearrange, repeat, pack, unpack
+from einops import rearrange, repeat, reduce, pack, unpack
 from einops.layers.torch import Rearrange, Reduce

 """
@@ -95,6 +96,37 @@ class MultiheadRMSNorm(Module):
     def forward(self, x):
         return self.rmsnorm(x) * (self.gamma + 1.)

+# attention pool
+
+class AttentionPool(Module):
+    def __init__(
+        self,
+        dim,
+        chunk_size
+    ):
+        """
+        taken from Enformer https://www.nature.com/articles/s41592-021-01252-x , in turn taken from somewhere else
+        """
+        super().__init__()
+        self.split_chunks = Rearrange('b (n c) d -> b n c d', c = chunk_size)
+        self.to_attn_logits = nn.Linear(dim, dim)
+
+        # default to average pool
+
+        nn.init.zeros_(self.to_attn_logits.weight)
+        nn.init.zeros_(self.to_attn_logits.bias)
+
+    def forward(
+        self,
+        x
+    ):
+        x = self.split_chunks(x)
+        attn_logits = self.to_attn_logits(x)
+
+        attn = attn_logits.softmax(dim = -2)
+
+        return reduce(x * attn, 'b n c d -> b n d', 'sum')
+
 # classes

 class MemoryMLP(Module):
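To see what the new AttentionPool does, here is a minimal standalone sketch (it does not import the package itself): because the logit projection is zero-initialized, the softmax weights start out uniform, so the module initially behaves exactly like mean pooling over each chunk.

# standalone sketch of chunk-wise attention pooling, mirroring the AttentionPool class added above
import torch
from torch import nn
from einops import rearrange, reduce

dim, chunk_size = 384, 64
to_attn_logits = nn.Linear(dim, dim)
nn.init.zeros_(to_attn_logits.weight)  # zero init -> uniform attention weights
nn.init.zeros_(to_attn_logits.bias)

x = torch.randn(2, 8 * chunk_size, dim)                        # (batch, seq, dim)
chunks = rearrange(x, 'b (n c) d -> b n c d', c = chunk_size)  # split sequence into chunks

attn = to_attn_logits(chunks).softmax(dim = -2)                # attention over positions within each chunk
pooled = reduce(chunks * attn, 'b n c d -> b n d', 'sum')      # (batch, num_chunks, dim)

# at zero init this coincides with plain mean pooling over each chunk
assert torch.allclose(pooled, reduce(chunks, 'b n c d -> b n d', 'mean'), atol = 1e-6)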
@@ -123,6 +155,46 @@ class MemoryMLP(Module):

         return x

+# memory mlp, but with gated residual + final projection
+
+class GatedResidualMemoryMLP(Module):
+    def __init__(
+        self,
+        dim,
+        depth
+    ):
+        super().__init__()
+        self.depth = depth
+
+        self.weights = ParameterList([
+            ParameterList([
+                Parameter(torch.randn(dim, dim)),
+                Parameter(torch.randn(dim * 2, dim)),
+            ]) for _ in range(depth)
+        ])
+
+        self.final_proj = Parameter(torch.randn(dim, dim))
+
+        for param in self.parameters():
+            nn.init.xavier_uniform_(param)
+
+    def forward(
+        self,
+        x
+    ):
+        for weight, to_gates in self.weights:
+            res = x
+
+            x = x @ weight
+            x = F.silu(x)
+
+            # gated residual
+
+            gates = cat((x, res), dim = -1) @ to_gates
+            x = res.lerp(x, gates.sigmoid())
+
+        return x @ self.final_proj
+
 # memory mlp with factorized weights
 # so can tradeoff capacity for smaller chunk sizes

@@ -224,6 +296,7 @@ class NeuralMemory(Module):
         default_step_transform_max_lr = 1e-2,
         per_parameter_lr_modulation = False, # allow outer network to control learning rate per weight matrix of memory network
         max_mem_layer_modulation = 1e1, # max of 10.
+        attn_pool_chunks = False,
         pre_rmsnorm = True,
         post_rmsnorm = True,
         learned_mem_model_weights = True,
@@ -304,10 +377,17 @@ class NeuralMemory(Module):
         self.empty_memory_embed = nn.Parameter(torch.zeros(dim))
         nn.init.normal_(self.empty_memory_embed, std = 0.02)

+        # whether to use averaging of chunks, or attention pooling
+
+        if not attn_pool_chunks:
+            chunk_reduce_module = Reduce('b (n c) ... -> b n ...', 'mean', c = chunk_size)
+        else:
+            chunk_reduce_module = AttentionPool(dim, chunk_size = chunk_size)
+
         # learned adaptive learning rate and momentum

         self.to_momentum = Sequential(
-            Reduce('b (n c) ... -> b n ...', 'mean', c = chunk_size),
+            chunk_reduce_module,
             LinearNoBias(dim, heads),
             Rearrange('b n h -> (b h) n 1')
         )
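The new attn_pool_chunks flag swaps the original mean reduction used for the chunk-derived momentum, per-layer learning-rate modulation and decay signals for the AttentionPool defined earlier. A brief sketch of toggling it, with constructor arguments mirroring the updated test:

# sketch: enabling attention pooling of chunks (arguments mirror the updated test above)
import torch
from titans_pytorch import NeuralMemory

mem = NeuralMemory(
    dim = 384,
    chunk_size = 64,
    attn_pool_chunks = True  # False keeps the original mean pooling over each chunk
)

seq = torch.randn(2, 1024, 384)
retrieved = mem(seq)  # forward signature varies between versions; treat as illustrative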
@@ -325,7 +405,7 @@ class NeuralMemory(Module):
         # per layer learning rate modulation

         self.to_layer_modulation = Sequential(
-            Reduce('b (n c) ... -> b n ...', 'mean', c = chunk_size),
+            chunk_reduce_module,
             LinearNoBias(dim, heads * self.num_memory_parameter_tensors),
             Rearrange('b n (h w) -> w (b h) n', h = heads),
             nn.Sigmoid()
@@ -340,7 +420,7 @@ class NeuralMemory(Module):
         # weight decay factor

         self.to_decay_factor = Sequential(
-            Reduce('b (n c) ... -> b n ...', 'mean', c = chunk_size),
+            chunk_reduce_module,
             LinearNoBias(dim, heads),
             Rearrange('b n h -> (b h) n 1')
         )
@@ -5,10 +5,10 @@ import numpy as np

 import torch
 from torch import nn, Tensor
-from torch.optim import Adam
 from torch.nn import functional as F
 from torch.utils.data import DataLoader, Dataset

+from adam_atan2_pytorch import AdoptAtan2
 from titans_pytorch import MemoryAsContextTransformer

 # constants
@@ -34,6 +34,7 @@ NEURAL_MEM_GATE_ATTN_OUTPUT = True
 WINDOW_SIZE = 32
 NEURAL_MEM_SEGMENT_LEN = WINDOW_SIZE // 2 # set smaller for more granularity for learning rate / momentum etc
 SLIDING_WINDOWS = True
+STORE_ATTN_POOL_CHUNKS = True # whether to use attention pooling for chunk derived momentum, per-layer lr mod, decay
 KV_RECON_LOSS_WEIGHT = 0.
 LEARNED_MEM_MODEL_WEIGHTS = True

@@ -86,6 +87,7 @@ model = MemoryAsContextTransformer(
     neural_memory_kwargs = dict(
         dim_head = 64,
         heads = 4,
+        attn_pool_chunks = STORE_ATTN_POOL_CHUNKS,
         use_accelerated_scan = USE_ACCELERATED_SCAN,
         learned_mem_model_weights = LEARNED_MEM_MODEL_WEIGHTS,
         default_model_kwargs = dict(
@@ -122,11 +124,11 @@ val_loader = cycle(DataLoader(val_dataset, batch_size = BATCH_SIZE))

 # optimizer

-optim = Adam(model.parameters(), lr=LEARNING_RATE)
+optim = AdoptAtan2(model.parameters(), lr = LEARNING_RATE)

 # training

-for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10., desc='training'):
+for i in tqdm.tqdm(range(NUM_BATCHES), mininterval = 10., desc = 'training'):
     model.train()

     for __ in range(GRADIENT_ACCUMULATE_EVERY):
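The training example now optimizes with AdoptAtan2 from the adam-atan2-pytorch dependency added to the 'examples' extra. The swap in isolation, with a stand-in model and a placeholder learning rate:

# sketch of the optimizer swap; nn.Linear stands in for MemoryAsContextTransformer
import torch
from torch import nn
from adam_atan2_pytorch import AdoptAtan2

model = nn.Linear(512, 512)

optim = AdoptAtan2(model.parameters(), lr = 1e-4)  # replaces torch.optim.Adam(...); lr is a placeholder

loss = model(torch.randn(2, 512)).sum()
loss.backward()

optim.step()
optim.zero_grad()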
The remaining files in the package are unchanged.