titans-pytorch 0.1.5.tar.gz → 0.1.6.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.1.5
+Version: 0.1.6
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -46,6 +46,7 @@ Requires-Dist: torch>=2.2
 Requires-Dist: tqdm
 Requires-Dist: x-transformers
 Provides-Extra: examples
+Requires-Dist: adam-atan2-pytorch>=0.1.18; extra == 'examples'
 Requires-Dist: wandb; extra == 'examples'
 Provides-Extra: test
 Requires-Dist: pytest; extra == 'test'
@@ -1,6 +1,6 @@
 [project]
 name = "titans-pytorch"
-version = "0.1.5"
+version = "0.1.6"
 description = "Titans"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -45,6 +45,7 @@ Repository = "https://github.com/lucidrains/titans-pytorch"
 [project.optional-dependencies]
 
 examples = [
+    "adam-atan2-pytorch>=0.1.18",
     "wandb"
 ]
 
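The new adam-atan2-pytorch requirement is scoped to the examples extra, so the base install is unaffected. A minimal sketch of what the extra makes available once installed (e.g. via pip's extras syntax, 'titans-pytorch[examples]'); the import path is taken from the training-script hunk further below:

from adam_atan2_pytorch import AdoptAtan2  # provided by adam-atan2-pytorch>=0.1.18, the new 'examples' dependency
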
@@ -2,7 +2,8 @@ from titans_pytorch.titans import (
     NeuralMemory,
     MemoryMLP,
     MemoryAttention,
-    FactorizedMemoryMLP
+    FactorizedMemoryMLP,
+    GatedResidualMemoryMLP
 )
 
 from titans_pytorch.mac_transformer import (
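With this re-export, the new class is reachable from the package top level alongside the existing memory models, so downstream code does not need to import from titans_pytorch.titans directly. A minimal sketch of the import surface as of this version:

from titans_pytorch import (
    NeuralMemory,
    MemoryMLP,
    MemoryAttention,
    FactorizedMemoryMLP,
    GatedResidualMemoryMLP
)
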
@@ -1,10 +1,11 @@
 from __future__ import annotations
 from typing import Callable
+
 import math
 from functools import partial
 
 import torch
-from torch import nn, Tensor
+from torch import nn, cat, Tensor
 import torch.nn.functional as F
 from torch.nn import Linear, Module, Parameter, ParameterList
 from torch.func import functional_call, vmap, grad
@@ -154,6 +155,46 @@ class MemoryMLP(Module):
 
         return x
 
+# memory mlp, but with gated residual + final projection
+
+class GatedResidualMemoryMLP(Module):
+    def __init__(
+        self,
+        dim,
+        depth
+    ):
+        super().__init__()
+        self.depth = depth
+
+        self.weights = ParameterList([
+            ParameterList([
+                Parameter(torch.randn(dim, dim)),
+                Parameter(torch.randn(dim * 2, dim)),
+            ]) for _ in range(depth)
+        ])
+
+        self.final_proj = Parameter(torch.randn(dim, dim))
+
+        for param in self.parameters():
+            nn.init.xavier_uniform_(param)
+
+    def forward(
+        self,
+        x
+    ):
+        for weight, to_gates in self.weights:
+            res = x
+
+            x = x @ weight
+            x = F.silu(x)
+
+            # gated residual
+
+            gates = cat((x, res), dim = -1) @ to_gates
+            x = res.lerp(x, gates.sigmoid())
+
+        return x @ self.final_proj
+
 # memory mlp with factorized weights
 # so can tradeoff capacity for smaller chunk sizes
 
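For reference, the new memory model can be exercised on its own. A minimal sketch grounded in the constructor and forward pass shown in the hunk above; the batch and sequence sizes are arbitrary:

import torch
from titans_pytorch import GatedResidualMemoryMLP

mem_model = GatedResidualMemoryMLP(dim = 64, depth = 2)

x = torch.randn(2, 256, 64)   # (batch, seq, dim)
out = mem_model(x)            # (2, 256, 64) - square weights plus the final projection preserve the feature dimension

Like the other memory MLP variants, it is presumably meant to be handed to NeuralMemory as its inner memory model (e.g. via a model keyword), but that wiring is not part of this diff.
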
@@ -5,10 +5,10 @@ import numpy as np
 
 import torch
 from torch import nn, Tensor
-from torch.optim import Adam
 from torch.nn import functional as F
 from torch.utils.data import DataLoader, Dataset
 
+from adam_atan2_pytorch import AdoptAtan2
 from titans_pytorch import MemoryAsContextTransformer
 
 # constants
@@ -34,6 +34,7 @@ NEURAL_MEM_GATE_ATTN_OUTPUT = True
 WINDOW_SIZE = 32
 NEURAL_MEM_SEGMENT_LEN = WINDOW_SIZE // 2 # set smaller for more granularity for learning rate / momentum etc
 SLIDING_WINDOWS = True
+STORE_ATTN_POOL_CHUNKS = True # whether to use attention pooling for chunk derived momentum, per-layer lr mod, decay
 KV_RECON_LOSS_WEIGHT = 0.
 LEARNED_MEM_MODEL_WEIGHTS = True
 
@@ -86,6 +87,7 @@ model = MemoryAsContextTransformer(
     neural_memory_kwargs = dict(
         dim_head = 64,
         heads = 4,
+        attn_pool_chunks = STORE_ATTN_POOL_CHUNKS,
         use_accelerated_scan = USE_ACCELERATED_SCAN,
         learned_mem_model_weights = LEARNED_MEM_MODEL_WEIGHTS,
         default_model_kwargs = dict(
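The new flag reaches the neural memory through neural_memory_kwargs. A hedged sketch of that wiring: only the neural_memory_kwargs entries are taken from the hunk above, while the outer constructor arguments (num_tokens, dim, depth, segment_len) are assumed values for illustration and may not match the actual script:

from titans_pytorch import MemoryAsContextTransformer

STORE_ATTN_POOL_CHUNKS = True  # whether to use attention pooling for chunk derived momentum, per-layer lr mod, decay

model = MemoryAsContextTransformer(
    num_tokens = 256,      # assumed for illustration
    dim = 384,             # assumed for illustration
    depth = 8,             # assumed for illustration
    segment_len = 32,      # assumed for illustration
    neural_memory_kwargs = dict(
        dim_head = 64,
        heads = 4,
        attn_pool_chunks = STORE_ATTN_POOL_CHUNKS
    )
)
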
@@ -122,11 +124,11 @@ val_loader = cycle(DataLoader(val_dataset, batch_size = BATCH_SIZE))
 
 # optimizer
 
-optim = Adam(model.parameters(), lr=LEARNING_RATE)
+optim = AdoptAtan2(model.parameters(), lr = LEARNING_RATE)
 
 # training
 
-for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10., desc='training'):
+for i in tqdm.tqdm(range(NUM_BATCHES), mininterval = 10., desc = 'training'):
     model.train()
 
     for __ in range(GRADIENT_ACCUMULATE_EVERY):
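Taken together, the optimizer change is a drop-in swap of torch.optim.Adam for AdoptAtan2. A minimal sketch, assuming AdoptAtan2 follows the standard torch optimizer interface (step / zero_grad), which the one-line replacement in the diff implies; the toy model, batch, and learning rate are placeholders, and only the optimizer construction mirrors the diff:

import torch
from torch import nn
from adam_atan2_pytorch import AdoptAtan2

model = nn.Linear(512, 512)                        # placeholder model, not the transformer from the script
optim = AdoptAtan2(model.parameters(), lr = 2e-4)  # same construction as in the diff; lr value is a placeholder

for _ in range(10):
    x = torch.randn(4, 512)                        # placeholder batch
    loss = model(x).sum()
    loss.backward()

    optim.step()
    optim.zero_grad()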