titans-pytorch 0.2.0__tar.gz → 0.2.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {titans_pytorch-0.2.0 → titans_pytorch-0.2.1}/PKG-INFO +1 -1
- {titans_pytorch-0.2.0 → titans_pytorch-0.2.1}/pyproject.toml +1 -1
- {titans_pytorch-0.2.0 → titans_pytorch-0.2.1}/titans_pytorch/mac_transformer.py +5 -3
- {titans_pytorch-0.2.0 → titans_pytorch-0.2.1}/train_mac.py +1 -1
- {titans_pytorch-0.2.0 → titans_pytorch-0.2.1}/.github/workflows/python-publish.yml +0 -0
- {titans_pytorch-0.2.0 → titans_pytorch-0.2.1}/.github/workflows/test.yaml +0 -0
- {titans_pytorch-0.2.0 → titans_pytorch-0.2.1}/.gitignore +0 -0
- {titans_pytorch-0.2.0 → titans_pytorch-0.2.1}/LICENSE +0 -0
- {titans_pytorch-0.2.0 → titans_pytorch-0.2.1}/README.md +0 -0
- {titans_pytorch-0.2.0 → titans_pytorch-0.2.1}/data/README.md +0 -0
- {titans_pytorch-0.2.0 → titans_pytorch-0.2.1}/data/enwik8.gz +0 -0
- {titans_pytorch-0.2.0 → titans_pytorch-0.2.1}/fig1.png +0 -0
- {titans_pytorch-0.2.0 → titans_pytorch-0.2.1}/fig2.png +0 -0
- {titans_pytorch-0.2.0 → titans_pytorch-0.2.1}/tests/test_titans.py +0 -0
- {titans_pytorch-0.2.0 → titans_pytorch-0.2.1}/titans_pytorch/__init__.py +0 -0
- {titans_pytorch-0.2.0 → titans_pytorch-0.2.1}/titans_pytorch/associative_scan.py +0 -0
- {titans_pytorch-0.2.0 → titans_pytorch-0.2.1}/titans_pytorch/memory_models.py +0 -0
- {titans_pytorch-0.2.0 → titans_pytorch-0.2.1}/titans_pytorch/neural_memory.py +0 -0
@@ -491,7 +491,8 @@ class MemoryAsContextTransformer(Module):
|
|
491
491
|
aux_kv_recon_loss_weight = 0.,
|
492
492
|
use_flex_attn = False,
|
493
493
|
sliding_window_attn = False,
|
494
|
-
weight_tie_memory_model = False
|
494
|
+
weight_tie_memory_model = False,
|
495
|
+
prev_neural_mem_update_for_weights = None
|
495
496
|
):
|
496
497
|
super().__init__()
|
497
498
|
|
@@ -533,6 +534,7 @@ class MemoryAsContextTransformer(Module):
|
|
533
534
|
assert exists(neural_memory_model), '`neural_memory_model` must be explicitly set'
|
534
535
|
|
535
536
|
self.weight_tie_memory_model = weight_tie_memory_model
|
537
|
+
self.prev_neural_mem_update_for_weights = default(prev_neural_mem_update_for_weights, weight_tie_memory_model)
|
536
538
|
|
537
539
|
# value residual learning for neural memory
|
538
540
|
|
@@ -702,7 +704,7 @@ class MemoryAsContextTransformer(Module):
|
|
702
704
|
|
703
705
|
# math
|
704
706
|
|
705
|
-
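batch, seq_len, neural_mem_segment_len, segment_len, num_longterm_mem_tokens, attn_window_size = *x.shape, self.neural_memory_segment_len, self.segment_len, self.num_longterm_mem_tokens, self.attn_window_size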
|
707
|
+
batch, seq_len, neural_mem_segment_len, segment_len, num_longterm_mem_tokens, attn_window_size, prev_neural_mem_update_for_weights = *x.shape, self.neural_memory_segment_len, self.segment_len, self.num_longterm_mem_tokens, self.attn_window_size, self.prev_neural_mem_update_for_weights
|
706
708
|
|
707
709
|
seq_len_with_mem = self.seq_len_with_longterm_mem(seq_len)
|
708
710
|
|
@@ -814,7 +816,7 @@ class MemoryAsContextTransformer(Module):
|
|
814
816
|
if self.mem_add_value_residual:
|
815
817
|
mem_value_residual = next_mem_value_residual
|
816
818
|
|
817
|
-
if self.weight_tie_memory_model:
|
819
|
+
if prev_neural_mem_update_for_weights:
|
818
820
|
neural_memory_updates = next_neural_mem_cache.updates
|
819
821
|
|
820
822
|
if self.gate_attn_output:
|
@@ -164,6 +164,6 @@ for i in tqdm.tqdm(range(NUM_BATCHES), mininterval = 10., desc = 'training'):
|
|
164
164
|
prime = decode_tokens(inp)
|
165
165
|
print(f'%s \n\n %s', (prime, '*' * 100))
|
166
166
|
|
167
|
-
sample = model.sample(inp[None, ...], GENERATE_LENGTH, use_cache =
|
167
|
+
sample = model.sample(inp[None, ...], GENERATE_LENGTH, use_cache = USE_FAST_INFERENCE)
|
168
168
|
output_str = decode_tokens(sample[0])
|
169
169
|
print(output_str)
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|