titans-pytorch 0.2.0.tar.gz → 0.2.1.tar.gz

This diff covers the publicly released contents of the two package versions as they appear in the public registry. It is provided for informational purposes only and reflects the changes between those versions.
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.2.0
+Version: 0.2.1
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -1,6 +1,6 @@
 [project]
 name = "titans-pytorch"
-version = "0.2.0"
+version = "0.2.1"
 description = "Titans"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -491,7 +491,8 @@ class MemoryAsContextTransformer(Module):
         aux_kv_recon_loss_weight = 0.,
         use_flex_attn = False,
         sliding_window_attn = False,
-        weight_tie_memory_model = False
+        weight_tie_memory_model = False,
+        prev_neural_mem_update_for_weights = None
     ):
         super().__init__()
 
@@ -533,6 +534,7 @@ class MemoryAsContextTransformer(Module):
             assert exists(neural_memory_model), '`neural_memory_model` must be explicitly set'
 
         self.weight_tie_memory_model = weight_tie_memory_model
+        self.prev_neural_mem_update_for_weights = default(prev_neural_mem_update_for_weights, weight_tie_memory_model)
 
         # value residual learning for neural memory
 
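For context: the added line makes the new flag follow `weight_tie_memory_model` whenever it is not passed explicitly. A minimal sketch of the `exists`/`default` helper semantics this relies on (both helpers appear in the diffed code; the bodies below are the usual convention in this codebase, shown as an assumption rather than a verbatim copy):

    def exists(v):
        # a value "exists" when it is not None
        return v is not None

    def default(v, d):
        # fall back to d when v was not supplied
        return v if exists(v) else d

    # with the new kwarg left at its default of None, behavior follows
    # weight_tie_memory_model, so existing 0.2.0 callers are unaffected
    assert default(None, True) is True     # not supplied: falls back
    assert default(False, True) is False   # explicit override wins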
@@ -702,7 +704,7 @@ class MemoryAsContextTransformer(Module):
 
         # math
 
-        batch, seq_len, neural_mem_segment_len, segment_len, num_longterm_mem_tokens, attn_window_size, weight_tie_memory_model = *x.shape, self.neural_memory_segment_len, self.segment_len, self.num_longterm_mem_tokens, self.attn_window_size, self.weight_tie_memory_model
+        batch, seq_len, neural_mem_segment_len, segment_len, num_longterm_mem_tokens, attn_window_size, prev_neural_mem_update_for_weights = *x.shape, self.neural_memory_segment_len, self.segment_len, self.num_longterm_mem_tokens, self.attn_window_size, self.prev_neural_mem_update_for_weights
 
         seq_len_with_mem = self.seq_len_with_longterm_mem(seq_len)
 
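The changed line uses Python's starred tuple unpacking to pull the tensor shape and several attributes apart in one statement; only the last name in the target list changed. A self-contained illustration of the idiom, with made-up shapes and values:

    import torch

    x = torch.randint(0, 256, (2, 1024))     # token ids: (batch, seq_len)
    batch, seq_len, flag = *x.shape, True    # shape elements first, extras after
    assert (batch, seq_len, flag) == (2, 1024, True)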
@@ -814,7 +816,7 @@ class MemoryAsContextTransformer(Module):
                 if self.mem_add_value_residual:
                     mem_value_residual = next_mem_value_residual
 
-                if weight_tie_memory_model:
+                if prev_neural_mem_update_for_weights:
                     neural_memory_updates = next_neural_mem_cache.updates
 
             if self.gate_attn_output:
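Taken together, the forward pass now gates the reuse of the previous neural memory update on the new flag rather than on weight tying directly. A hedged usage sketch of how a caller might opt in explicitly; the constructor values below are illustrative only, and the `MemoryMLP` export and signature are assumed from the package's public API rather than confirmed by this diff (`segment_len`, `num_longterm_mem_tokens`, `neural_memory_model`, and the two flags do appear in the diffed code):

    from titans_pytorch import MemoryAsContextTransformer, MemoryMLP

    transformer = MemoryAsContextTransformer(
        num_tokens = 256,
        dim = 256,
        depth = 2,
        segment_len = 128,
        num_longterm_mem_tokens = 16,
        neural_memory_model = MemoryMLP(dim = 64, depth = 2),  # must be set explicitly when weight tying
        weight_tie_memory_model = True,
        prev_neural_mem_update_for_weights = True  # new in 0.2.1
    )

Leaving `prev_neural_mem_update_for_weights` unset keeps 0.2.0 behavior, since it defaults to the value of `weight_tie_memory_model`.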
@@ -164,6 +164,6 @@ for i in tqdm.tqdm(range(NUM_BATCHES), mininterval = 10., desc = 'training'):
         prime = decode_tokens(inp)
         print(f'%s \n\n %s', (prime, '*' * 100))
 
-        sample = model.sample(inp[None, ...], GENERATE_LENGTH, use_cache = True)
+        sample = model.sample(inp[None, ...], GENERATE_LENGTH, use_cache = USE_FAST_INFERENCE)
         output_str = decode_tokens(sample[0])
         print(output_str)
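`USE_FAST_INFERENCE` is referenced but not defined in this hunk; presumably it is a constant declared alongside the other training-script settings. A sketch of the assumed arrangement (the name is taken from the diff; its value and placement are assumptions):

    # near the other hyperparameter constants at the top of the script
    USE_FAST_INFERENCE = True   # assumed default; toggles cached sampling

    # later, during the periodic sampling step
    sample = model.sample(inp[None, ...], GENERATE_LENGTH, use_cache = USE_FAST_INFERENCE)

This turns the previously hard-coded `use_cache = True` into a single switch, so cached and uncached sampling can be compared from one place.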