titans-pytorch 0.0.43__tar.gz → 0.0.44__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {titans_pytorch-0.0.43 → titans_pytorch-0.0.44}/PKG-INFO +1 -1
- {titans_pytorch-0.0.43 → titans_pytorch-0.0.44}/pyproject.toml +1 -1
- {titans_pytorch-0.0.43 → titans_pytorch-0.0.44}/titans_pytorch/mac_transformer.py +11 -9
- {titans_pytorch-0.0.43 → titans_pytorch-0.0.44}/train_mac.py +1 -0
- {titans_pytorch-0.0.43 → titans_pytorch-0.0.44}/.github/workflows/python-publish.yml +0 -0
- {titans_pytorch-0.0.43 → titans_pytorch-0.0.44}/.github/workflows/test.yaml +0 -0
- {titans_pytorch-0.0.43 → titans_pytorch-0.0.44}/.gitignore +0 -0
- {titans_pytorch-0.0.43 → titans_pytorch-0.0.44}/LICENSE +0 -0
- {titans_pytorch-0.0.43 → titans_pytorch-0.0.44}/README.md +0 -0
- {titans_pytorch-0.0.43 → titans_pytorch-0.0.44}/data/README.md +0 -0
- {titans_pytorch-0.0.43 → titans_pytorch-0.0.44}/data/enwik8.gz +0 -0
- {titans_pytorch-0.0.43 → titans_pytorch-0.0.44}/fig1.png +0 -0
- {titans_pytorch-0.0.43 → titans_pytorch-0.0.44}/fig2.png +0 -0
- {titans_pytorch-0.0.43 → titans_pytorch-0.0.44}/tests/test_titans.py +0 -0
- {titans_pytorch-0.0.43 → titans_pytorch-0.0.44}/titans_pytorch/__init__.py +0 -0
- {titans_pytorch-0.0.43 → titans_pytorch-0.0.44}/titans_pytorch/associative_scan.py +0 -0
- {titans_pytorch-0.0.43 → titans_pytorch-0.0.44}/titans_pytorch/titans.py +0 -0
|
@@ -224,7 +224,7 @@ class MemoryAsContextTransformer(Module):
|
|
|
224
224
|
self.layers = ModuleList([])
|
|
225
225
|
|
|
226
226
|
self.neural_mem_layers = ModuleList([])
|
|
227
|
-
neural_memory_segment_len = default(neural_memory_segment_len, num_longterm_mem_tokens + segment_len)
|
|
227
|
+
self.neural_memory_segment_len = default(neural_memory_segment_len, num_longterm_mem_tokens + segment_len)
|
|
228
228
|
|
|
229
229
|
layers = tuple(range(1, depth + 1))
|
|
230
230
|
|
|
@@ -245,7 +245,7 @@ class MemoryAsContextTransformer(Module):
|
|
|
245
245
|
|
|
246
246
|
mem = NeuralMemory(
|
|
247
247
|
dim = dim,
|
|
248
|
-
chunk_size = neural_memory_segment_len,
|
|
248
|
+
chunk_size = self.neural_memory_segment_len,
|
|
249
249
|
**neural_memory_kwargs
|
|
250
250
|
)
|
|
251
251
|
|
|
@@ -287,10 +287,7 @@ class MemoryAsContextTransformer(Module):
|
|
|
287
287
|
|
|
288
288
|
# math
|
|
289
289
|
|
|
290
|
-
batch, seq_len, segment_len, num_longterm_mem_tokens= *x.shape, self.segment_len, self.num_longterm_mem_tokens
|
|
291
|
-
|
|
292
|
-
windows = ceil(seq_len / segment_len)
|
|
293
|
-
total_segment_len = segment_len + num_longterm_mem_tokens
|
|
290
|
+
batch, seq_len, neural_mem_segment_len, segment_len, num_longterm_mem_tokens = *x.shape, self.neural_memory_segment_len, self.segment_len, self.num_longterm_mem_tokens
|
|
294
291
|
|
|
295
292
|
# token embedding
|
|
296
293
|
|
|
@@ -305,11 +302,16 @@ class MemoryAsContextTransformer(Module):
|
|
|
305
302
|
|
|
306
303
|
x = inverse_segment(x)
|
|
307
304
|
|
|
305
|
+
seq_len_with_mem = x.shape[-2]
|
|
306
|
+
|
|
308
307
|
# apply axial positional embedding
|
|
309
308
|
# so intra and inter segment can be more easily discerned by the network
|
|
310
309
|
|
|
311
|
-
|
|
312
|
-
|
|
310
|
+
neural_mem_windows = ceil(seq_len_with_mem / neural_mem_segment_len)
|
|
311
|
+
|
|
312
|
+
pos_emb = self.axial_pos_emb((neural_mem_windows, neural_mem_segment_len), flatten = True)
|
|
313
|
+
|
|
314
|
+
x = x + pos_emb[:seq_len_with_mem]
|
|
313
315
|
|
|
314
316
|
# value residual
|
|
315
317
|
|
|
@@ -334,7 +336,7 @@ class MemoryAsContextTransformer(Module):
|
|
|
334
336
|
|
|
335
337
|
# excise out the memories
|
|
336
338
|
|
|
337
|
-
x, inverse_segment = pad_and_segment_with_inverse(x, total_segment_len)
|
|
339
|
+
x, inverse_segment = pad_and_segment_with_inverse(x, segment_len + num_longterm_mem_tokens)
|
|
338
340
|
|
|
339
341
|
x, _ = unpack(x, mem_ps, 'b * d')
|
|
340
342
|
|
|
@@ -62,6 +62,7 @@ model = MemoryAsContextTransformer(
|
|
|
62
62
|
num_persist_mem_tokens = NUM_PERSIST_MEM,
|
|
63
63
|
num_longterm_mem_tokens = NUM_LONGTERM_MEM,
|
|
64
64
|
neural_memory_layers = NEURAL_MEM_LAYERS,
|
|
65
|
+
neural_memory_segment_len = WINDOW_SIZE // 2,
|
|
65
66
|
neural_memory_kwargs = dict(
|
|
66
67
|
dim_head = 64,
|
|
67
68
|
heads = 4,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|