titans-pytorch 0.2.18__py3-none-any.whl → 0.2.19__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -66,6 +66,12 @@ def xnor(x, y):
66
66
  def divisible_by(num, den):
67
67
  return (num % den) == 0
68
68
 
69
+ def tuple_index_set(t: tuple, index, value):
70
+ klass = type(t)
71
+ t = list(t)
72
+ t[index] = value
73
+ return klass(*t)
74
+
69
75
  def safe_cat(inputs, dim = -2):
70
76
  inputs = tuple(filter(exists, inputs))
71
77
 
@@ -296,7 +302,7 @@ class NeuralMemory(Module):
296
302
  init_adaptive_step_bias = None,
297
303
  init_momentum_bias = None,
298
304
  init_decay_bias = None,
299
- learned_weight_residual = False,
305
+ accept_weight_residual = False,
300
306
  default_model_kwargs: dict = dict(
301
307
  depth = 2
302
308
  )
@@ -445,7 +451,7 @@ class NeuralMemory(Module):
445
451
  nn.Linear(dim, heads),
446
452
  Rearrange('b n h -> b h n'),
447
453
  nn.Sigmoid()
448
- ) if learned_weight_residual else None
454
+ ) if accept_weight_residual else None
449
455
 
450
456
  # allow for softclamp the gradient norms for storing memories
451
457
 
@@ -584,7 +590,6 @@ class NeuralMemory(Module):
584
590
  if exists(self.to_learned_weight_residual_mix):
585
591
  mix = self.to_learned_weight_residual_mix(chunked_seq)
586
592
  mix = rearrange(mix, 'b h n -> (b h) n')
587
-
588
593
  prev_weights = prev_weights.apply(lambda t: einx.multiply('bh n, bh n ... -> bh n ...', mix, t))
589
594
 
590
595
  weights_for_surprise = weights_for_surprise + prev_weights
@@ -854,9 +859,9 @@ class NeuralMemory(Module):
854
859
 
855
860
  weights = last_update
856
861
 
857
- next_neural_mem_state = list(next_neural_mem_state)
858
- next_neural_mem_state[1] = last_update
859
- next_neural_mem_state = NeuralMemCache(*next_neural_mem_state)
862
+ next_neural_mem_state = tuple_index_set(next_neural_mem_state, 1, last_update)
863
+
864
+ next_neural_mem_state = tuple_index_set(next_neural_mem_state, -1, updates)
860
865
 
861
866
  # retrieve
862
867
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: titans-pytorch
3
- Version: 0.2.18
3
+ Version: 0.2.19
4
4
  Summary: Titans
5
5
  Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
6
6
  Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -2,8 +2,8 @@ titans_pytorch/__init__.py,sha256=Y3m_ZlpEqYwp-Md1ARhNGJxq8bQp8ty1o039nZOOJo0,27
2
2
  titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
3
3
  titans_pytorch/mac_transformer.py,sha256=J9MHFViWWehBOFOKx7ry_X2k8nAXaAZFFeCwGtudZyk,24942
4
4
  titans_pytorch/memory_models.py,sha256=Q9SAIyAbStF5Tz0EhvRbn3yAdE3nk3xKc1ndieIe714,4671
5
- titans_pytorch/neural_memory.py,sha256=A74v2g4OUXJIqXLu1Rk7FDsTlhveJIHsiamsWN9e7cI,26602
6
- titans_pytorch-0.2.18.dist-info/METADATA,sha256=QFJI4OoZZSdgZd6FJInAG5-z2ubpvO5I0aCRNKwhDnQ,6812
7
- titans_pytorch-0.2.18.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
8
- titans_pytorch-0.2.18.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
9
- titans_pytorch-0.2.18.dist-info/RECORD,,
5
+ titans_pytorch/neural_memory.py,sha256=e1TENsK0IptpapFRtQux-Uii2MFvisoiQLzTQboHc50,26709
6
+ titans_pytorch-0.2.19.dist-info/METADATA,sha256=6frwjwq4CX7VcdCMh8Op9WeDHLRD-oeo8DlTde538UI,6812
7
+ titans_pytorch-0.2.19.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
8
+ titans_pytorch-0.2.19.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
9
+ titans_pytorch-0.2.19.dist-info/RECORD,,