titans-pytorch 0.0.39__tar.gz → 0.0.40__tar.gz

This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between versions as they appear in their respective public registries.
Files changed (20)
  1. {titans_pytorch-0.0.39 → titans_pytorch-0.0.40}/PKG-INFO +1 -1
  2. {titans_pytorch-0.0.39 → titans_pytorch-0.0.40}/pyproject.toml +1 -1
  3. {titans_pytorch-0.0.39 → titans_pytorch-0.0.40}/titans_pytorch/mac_transformer.py +32 -3
  4. {titans_pytorch-0.0.39 → titans_pytorch-0.0.40}/train_mac.py +1 -1
  5. {titans_pytorch-0.0.39 → titans_pytorch-0.0.40}/.github/workflows/python-publish.yml +0 -0
  6. {titans_pytorch-0.0.39 → titans_pytorch-0.0.40}/.github/workflows/test.yaml +0 -0
  7. {titans_pytorch-0.0.39 → titans_pytorch-0.0.40}/.gitignore +0 -0
  8. {titans_pytorch-0.0.39 → titans_pytorch-0.0.40}/LICENSE +0 -0
  9. {titans_pytorch-0.0.39 → titans_pytorch-0.0.40}/README.md +0 -0
  10. {titans_pytorch-0.0.39 → titans_pytorch-0.0.40}/data/README.md +0 -0
  11. {titans_pytorch-0.0.39 → titans_pytorch-0.0.40}/data/enwik8.gz +0 -0
  12. {titans_pytorch-0.0.39 → titans_pytorch-0.0.40}/fig1.png +0 -0
  13. {titans_pytorch-0.0.39 → titans_pytorch-0.0.40}/fig2.png +0 -0
  14. {titans_pytorch-0.0.39 → titans_pytorch-0.0.40}/requirements.txt +0 -0
  15. {titans_pytorch-0.0.39 → titans_pytorch-0.0.40}/tests/test_titans.py +0 -0
  16. {titans_pytorch-0.0.39 → titans_pytorch-0.0.40}/titans_pytorch/__init__.py +0 -0
  17. {titans_pytorch-0.0.39 → titans_pytorch-0.0.40}/titans_pytorch/associative_scan.py +0 -0
  18. {titans_pytorch-0.0.39 → titans_pytorch-0.0.40}/titans_pytorch/titans.py +0 -0
  19. {titans_pytorch-0.0.39 → titans_pytorch-0.0.40}/titans_pytorch/titans_attn_memory.py +0 -0
  20. {titans_pytorch-0.0.39 → titans_pytorch-0.0.40}/train.py +0 -0
{titans_pytorch-0.0.39 → titans_pytorch-0.0.40}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.0.39
+Version: 0.0.40
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
{titans_pytorch-0.0.39 → titans_pytorch-0.0.40}/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "titans-pytorch"
-version = "0.0.39"
+version = "0.0.40"
 description = "Titans"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
{titans_pytorch-0.0.39 → titans_pytorch-0.0.40}/titans_pytorch/mac_transformer.py
@@ -94,6 +94,7 @@ class SegmentedAttention(Module):
         num_longterm_mem_tokens = 0,
         dim_head = 64,
         heads = 8,
+        accept_value_residual = False,
         attend_kwargs: dict = dict()
     ):
         super().__init__()
@@ -108,6 +109,12 @@ class SegmentedAttention(Module):
         self.to_qkv = LinearNoBias(dim, dim_inner * 3)
         self.to_out = LinearNoBias(dim_inner, dim)
 
+        self.to_learned_v_mix = nn.Sequential(
+            nn.Linear(dim, heads),
+            Rearrange('b n h -> b h n 1'),
+            nn.Sigmoid()
+        ) if accept_value_residual else None
+
         self.segment_len = segment_len
         self.num_longterm_mem_tokens = num_longterm_mem_tokens
 
@@ -118,7 +125,13 @@
 
         self.persistent_memory = nn.Parameter(torch.zeros(2, heads, num_persist_mem_tokens, dim_head))
 
-    def forward(self, seq):
+    def forward(
+        self,
+        seq,
+        value_residual = None
+    ):
+        assert not (exists(value_residual) ^ exists(self.to_learned_v_mix))
+
         segment_len, num_longterm_mem_tokens = self.segment_len, self.num_longterm_mem_tokens
         total_segment_len = segment_len + num_longterm_mem_tokens
 
@@ -136,6 +149,14 @@ class SegmentedAttention(Module):
         q, k, v = self.to_qkv(seq).chunk(3, dim = -1)
         q, k, v = map(self.split_heads, (q, k, v))
 
+        # value residual
+
+        orig_v = v
+
+        if exists(self.to_learned_v_mix):
+            mix = self.to_learned_v_mix(seq)
+            v = v.lerp(value_residual, mix)
+
         # take care of persistent memory key / values
 
         pmk, pmv = tuple(repeat(t, 'h n d -> b h n d', b = seq.shape[0]) for t in self.persistent_memory)
@@ -159,7 +180,7 @@ class SegmentedAttention(Module):
 
         out = inverse_segment(out)
 
-        return out
+        return out, orig_v
 
 # MAC transformer
 
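The hunks above add a value residual path to SegmentedAttention: when the layer is constructed with `accept_value_residual = True`, a small per-head sigmoid gate (`to_learned_v_mix`) is predicted from the token embeddings and used to interpolate the layer's own values toward a set of residual values passed into `forward`. Below is a minimal standalone sketch of just that mixing step; the tensor sizes and variable names are illustrative assumptions, and only the gate construction and the `lerp` call mirror the diff.

# Sketch of the value-residual mix (illustrative shapes, not the package's API)
import torch
from torch import nn
from einops.layers.torch import Rearrange

batch, seq_len, dim, heads, dim_head = 2, 16, 64, 8, 32

# per-head, per-position gate in [0, 1], built the same way as `to_learned_v_mix`
to_learned_v_mix = nn.Sequential(
    nn.Linear(dim, heads),
    Rearrange('b n h -> b h n 1'),
    nn.Sigmoid()
)

seq = torch.randn(batch, seq_len, dim)                         # token embeddings
v = torch.randn(batch, heads, seq_len, dim_head)               # this layer's values
value_residual = torch.randn(batch, heads, seq_len, dim_head)  # residual values from an earlier layer

mix = to_learned_v_mix(seq)      # (batch, heads, seq_len, 1)
v = v.lerp(value_residual, mix)  # v * (1 - mix) + value_residual * mix

Note that `forward` now returns the unmixed values (`orig_v`) alongside the attention output, and the new assert enforces the pairing: a layer built with `accept_value_residual = True` must be given a `value_residual`, and one built without it must not.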
{titans_pytorch-0.0.39 → titans_pytorch-0.0.40}/titans_pytorch/mac_transformer.py (continued)
@@ -210,6 +231,7 @@ class MemoryAsContextTransformer(Module):
         assert not (num_longterm_mem_tokens > 0 and len(neural_memory_layers) == 0), 'empty `neural_memory_layers` when longterm memory tokens are present'
 
         for layer in layers:
+            is_first = layer == 1
 
             # neural memory
 
@@ -235,6 +257,7 @@ class MemoryAsContextTransformer(Module):
                 dim_head = dim_head,
                 heads = heads,
                 segment_len = segment_len,
+                accept_value_residual = not is_first,
                 num_longterm_mem_tokens = num_longterm_mem_tokens,
                 num_persist_mem_tokens = num_persist_mem_tokens
             )
@@ -285,6 +308,10 @@ class MemoryAsContextTransformer(Module):
         pos_emb = self.axial_pos_emb((windows, total_segment_len), flatten = True)
         x = x + pos_emb[:x.shape[-2]]
 
+        # value residual
+
+        value_residual = None
+
         # expand and reduce streams for hyper connections
 
         x = self.expand_streams(x)
@@ -295,7 +322,9 @@ class MemoryAsContextTransformer(Module):
             x = maybe_neural_mem(x)
 
 
-            x = attn(x)
+            x, values = attn(x, value_residual = value_residual)
+
+            value_residual = default(value_residual, values)
 
             x = ff(x)
 
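In MemoryAsContextTransformer, only the first layer is built without `accept_value_residual`, and its unmixed values become the residual for every later layer: `value_residual` starts as `None`, so `default(value_residual, values)` fills it once on the first iteration and leaves it untouched afterwards. A hedged sketch of that control flow, with stand-in layer callables in place of the real modules:

# Control-flow sketch of threading the value residual through the layer stack
# (attn_layers / ff_layers are placeholders; only the threading mirrors the diff)
def default(value, fallback):
    # same convention as the repo's `default` helper: keep `value` unless it is None
    return value if value is not None else fallback

def run_layers(x, attn_layers, ff_layers):
    value_residual = None

    for attn, ff in zip(attn_layers, ff_layers):
        # each attention block returns (output, its own unmixed values)
        x, values = attn(x, value_residual = value_residual)

        # set once on the first layer; later layers reuse the first layer's values
        value_residual = default(value_residual, values)

        x = ff(x)

    return x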
{titans_pytorch-0.0.39 → titans_pytorch-0.0.40}/train_mac.py
@@ -24,7 +24,7 @@ SHOULD_GENERATE = False
 SEQ_LEN = 512
 
 PROJECT_NAME = 'titans-mac-transformer'
-WANDB_ONLINE = True # turn this on to pipe experiment to cloud
+WANDB_ONLINE = False # turn this on to pipe experiment to cloud
 NEURAL_MEMORY_DEPTH = 2
 NUM_PERSIST_MEM = 4
 NUM_LONGTERM_MEM = 4