titans-pytorch 0.2.18__py3-none-any.whl → 0.2.20__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -560,7 +560,7 @@ class MemoryAsContextTransformer(Module):
560
560
  chunk_size = self.neural_memory_segment_len,
561
561
  batch_size = neural_memory_batch_size,
562
562
  model = deepcopy(neural_memory_model),
563
- learned_weight_residual = neural_mem_weight_residual and not is_first_neural_mem,
563
+ accept_weight_residual = neural_mem_weight_residual and not is_first_neural_mem,
564
564
  **neural_memory_kwargs
565
565
  )
566
566
 
@@ -66,6 +66,12 @@ def xnor(x, y):
66
66
  def divisible_by(num, den):
67
67
  return (num % den) == 0
68
68
 
69
+ def tuple_index_set(t: tuple, index, value):
70
+ klass = type(t)
71
+ t = list(t)
72
+ t[index] = value
73
+ return klass(*t)
74
+
69
75
  def safe_cat(inputs, dim = -2):
70
76
  inputs = tuple(filter(exists, inputs))
71
77
 
@@ -296,7 +302,7 @@ class NeuralMemory(Module):
296
302
  init_adaptive_step_bias = None,
297
303
  init_momentum_bias = None,
298
304
  init_decay_bias = None,
299
- learned_weight_residual = False,
305
+ accept_weight_residual = False,
300
306
  default_model_kwargs: dict = dict(
301
307
  depth = 2
302
308
  )
@@ -445,7 +451,7 @@ class NeuralMemory(Module):
445
451
  nn.Linear(dim, heads),
446
452
  Rearrange('b n h -> b h n'),
447
453
  nn.Sigmoid()
448
- ) if learned_weight_residual else None
454
+ ) if accept_weight_residual else None
449
455
 
450
456
  # allow for softclamp the gradient norms for storing memories
451
457
 
@@ -584,7 +590,6 @@ class NeuralMemory(Module):
584
590
  if exists(self.to_learned_weight_residual_mix):
585
591
  mix = self.to_learned_weight_residual_mix(chunked_seq)
586
592
  mix = rearrange(mix, 'b h n -> (b h) n')
587
-
588
593
  prev_weights = prev_weights.apply(lambda t: einx.multiply('bh n, bh n ... -> bh n ...', mix, t))
589
594
 
590
595
  weights_for_surprise = weights_for_surprise + prev_weights
@@ -854,9 +859,9 @@ class NeuralMemory(Module):
854
859
 
855
860
  weights = last_update
856
861
 
857
- next_neural_mem_state = list(next_neural_mem_state)
858
- next_neural_mem_state[1] = last_update
859
- next_neural_mem_state = NeuralMemCache(*next_neural_mem_state)
862
+ next_neural_mem_state = tuple_index_set(next_neural_mem_state, 1, last_update)
863
+
864
+ next_neural_mem_state = tuple_index_set(next_neural_mem_state, -1, updates)
860
865
 
861
866
  # retrieve
862
867
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: titans-pytorch
3
- Version: 0.2.18
3
+ Version: 0.2.20
4
4
  Summary: Titans
5
5
  Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
6
6
  Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -204,6 +204,6 @@ $ python train_mac.py
204
204
  eprint = {2501.12352},
205
205
  archivePrefix = {arXiv},
206
206
  primaryClass = {cs.LG},
207
- url = {https://arxiv.org/abs/2501.12352},
207
+ url = {https://arxiv.org/abs/2501.12352},
208
208
  }
209
209
  ```
@@ -0,0 +1,9 @@
1
+ titans_pytorch/__init__.py,sha256=Y3m_ZlpEqYwp-Md1ARhNGJxq8bQp8ty1o039nZOOJo0,276
2
+ titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
3
+ titans_pytorch/mac_transformer.py,sha256=5rO4GQxSyFWWEc3pc3xNyG0sK5EXE7MmxKI-_kEMl2M,24941
4
+ titans_pytorch/memory_models.py,sha256=Q9SAIyAbStF5Tz0EhvRbn3yAdE3nk3xKc1ndieIe714,4671
5
+ titans_pytorch/neural_memory.py,sha256=e1TENsK0IptpapFRtQux-Uii2MFvisoiQLzTQboHc50,26709
6
+ titans_pytorch-0.2.20.dist-info/METADATA,sha256=8v5puM3CwbPaYDt2aZaqWruPj8whCt_moiQUVE1vk5U,6816
7
+ titans_pytorch-0.2.20.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
8
+ titans_pytorch-0.2.20.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
9
+ titans_pytorch-0.2.20.dist-info/RECORD,,
@@ -1,9 +0,0 @@
1
- titans_pytorch/__init__.py,sha256=Y3m_ZlpEqYwp-Md1ARhNGJxq8bQp8ty1o039nZOOJo0,276
2
- titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
3
- titans_pytorch/mac_transformer.py,sha256=J9MHFViWWehBOFOKx7ry_X2k8nAXaAZFFeCwGtudZyk,24942
4
- titans_pytorch/memory_models.py,sha256=Q9SAIyAbStF5Tz0EhvRbn3yAdE3nk3xKc1ndieIe714,4671
5
- titans_pytorch/neural_memory.py,sha256=A74v2g4OUXJIqXLu1Rk7FDsTlhveJIHsiamsWN9e7cI,26602
6
- titans_pytorch-0.2.18.dist-info/METADATA,sha256=QFJI4OoZZSdgZd6FJInAG5-z2ubpvO5I0aCRNKwhDnQ,6812
7
- titans_pytorch-0.2.18.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
8
- titans_pytorch-0.2.18.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
9
- titans_pytorch-0.2.18.dist-info/RECORD,,