titans-pytorch 0.2.6__tar.gz → 0.2.7__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: titans-pytorch
- Version: 0.2.6
+ Version: 0.2.7
  Summary: Titans
  Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
  Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -27,7 +27,7 @@ License: MIT License
  OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
  SOFTWARE.
  License-File: LICENSE
- Keywords: artificial intelligence,deep learning,linear attention,neural memory module,test time training
+ Keywords: artificial intelligence,deep learning,linear attention,memory,test time training
  Classifier: Development Status :: 4 - Beta
  Classifier: Intended Audience :: Developers
  Classifier: License :: OSI Approved :: MIT License
@@ -1,6 +1,6 @@
  [project]
  name = "titans-pytorch"
- version = "0.2.6"
+ version = "0.2.7"
  description = "Titans"
  authors = [
  { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -11,9 +11,9 @@ license = { file = "LICENSE" }
  keywords = [
  'artificial intelligence',
  'deep learning',
- 'neural memory module',
  'test time training',
- 'linear attention'
+ 'linear attention',
+ 'memory',
  ]

  classifiers=[
@@ -24,9 +24,9 @@ def torch_default_dtype(dtype):

  # main test

- @pytest.mark.parametrize('seq_len', (32, 1024, 77))
+ @pytest.mark.parametrize('seq_len', (32, 512, 77))
  @pytest.mark.parametrize('silu', (False, True))
- @pytest.mark.parametrize('attn_pool_chunks', (False, True))
+ @pytest.mark.parametrize('chunk_size, attn_pool_chunks', ((64, True), (64, False), (1, False)))
  @pytest.mark.parametrize('momentum', (False, True))
  @pytest.mark.parametrize('qk_rmsnorm', (False, True))
  @pytest.mark.parametrize('max_grad_norm', (None, 2.))
@@ -35,6 +35,7 @@ def test_titans(
  seq_len,
  silu,
  attn_pool_chunks,
+ chunk_size,
  momentum,
  qk_rmsnorm,
  max_grad_norm,
@@ -42,7 +43,7 @@ def test_titans(
  ):
  mem = NeuralMemory(
  dim = 384,
- chunk_size = 64,
+ chunk_size = chunk_size,
  activation = nn.SiLU() if silu else None,
  attn_pool_chunks = attn_pool_chunks,
  max_grad_norm = max_grad_norm,
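
The test now sweeps chunk_size explicitly (including chunk_size = 1) instead of only toggling attn_pool_chunks. A minimal construction sketch mirroring the arguments the test passes; the forward-call pattern at the end is an assumption taken from the project README and may differ across versions:

    import torch
    from torch import nn
    from titans_pytorch import NeuralMemory

    # arguments mirror the NeuralMemory construction in test_titans; the specific
    # values are one point from the new parametrize grid
    mem = NeuralMemory(
        dim = 384,
        chunk_size = 1,          # per-token updates; the test also covers chunk_size = 64
        activation = nn.SiLU(),
        attn_pool_chunks = False,
        max_grad_norm = 2.
    )

    seq = torch.randn(2, 512, 384)
    retrieved = mem(seq)         # assumed call pattern; return values may differ by version
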
@@ -51,6 +51,9 @@ def default(*args):
  return arg
  return None

+ def identity(t):
+ return t
+
  def xnor(x, y):
  return not (x ^ y)

@@ -64,9 +67,6 @@ def safe_cat(inputs, dim = -2):

  return cat(inputs, dim = dim)

- def identity(t):
- return t
-
  def dict_get_shape(td):
  return {k: v.shape for k, v in td.items()}

@@ -454,14 +454,14 @@ class NeuralMemory(Module):

  weights = TensorDict(weights)

- # allow for neural memory of a previous layer and the past to produce gradients that become the weights of the current one generating the surprise
- # think this is necessary otherwise the memory model is static (unless if paper is misunderstood)
- # improvise (or perhaps correcting to) a solution
+ # allow for neural memory of a previous layer to influence surprise of current layer
+
+ weights_for_surprise = weights

  if exists(prev_layer_updates):
  prev_layer_updates = TensorDict(prev_layer_updates)

- weights = weights + prev_layer_updates
+ weights_for_surprise = weights_for_surprise + prev_layer_updates

  per_sample_grad_fn = self.per_sample_grad_fn_expanded_weights # the weights will now have a batch * chunk dimension

@@ -506,11 +506,11 @@ class NeuralMemory(Module):
  # flatten batch and time if surprise depends on previous layer memory model

  if exists(prev_layer_updates):
- weights = weights.apply(lambda t: rearrange(t, 'b n ... -> (b n) ...'))
+ weights_for_surprise = weights_for_surprise.apply(lambda t: rearrange(t, 'b n ... -> (b n) ...'))

  # get grads and extra auxiliary loss (for backwarding through qkv projection in base neural memory module)

- grads, aux_kv_recon_loss = per_sample_grad_fn(dict(weights), keys, adaptive_lr, values)
+ grads, aux_kv_recon_loss = per_sample_grad_fn(dict(weights_for_surprise), keys, adaptive_lr, values)

  grads = TensorDict(grads)

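
The two hunks above stop mutating the stored weights when a previous layer's updates are mixed in: the updates are now folded only into a separate weights_for_surprise copy that feeds per_sample_grad_fn. A simplified sketch of the pattern using plain dicts (TensorDict and the surrounding method are omitted; the helper name here is hypothetical):

    def surprise_grads(weights, prev_layer_updates, per_sample_grad_fn, keys, adaptive_lr, values):
        # stored weights are left untouched; only this copy sees the previous layer's updates
        weights_for_surprise = dict(weights)

        if prev_layer_updates is not None:
            weights_for_surprise = {
                name: weights_for_surprise[name] + prev_layer_updates[name]
                for name in weights_for_surprise
            }

        # per_sample_grad_fn stands in for the library's per-sample gradient function
        return per_sample_grad_fn(weights_for_surprise, keys, adaptive_lr, values)
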
@@ -536,7 +536,15 @@ class NeuralMemory(Module):

  if not exists(past_state):
  empty_dict = {key: None for key in weights.keys()}
- past_state = (empty_dict, empty_dict)
+
+ # minibatch_init_weight corresponds to W0 in figure 7 of TTT paper
+
+ minibatch_init_weight = weights
+
+ if dict_get_shape(weights) == self.init_weight_shape:
+ minibatch_init_weight = weights.apply(lambda t: repeat(t, '... -> b 1 (...)', b = batch * heads))
+
+ past_state = (minibatch_init_weight, empty_dict)

  past_last_update, past_last_momentum = past_state

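
With this change an empty past state is seeded with the initial memory-model weights (W0 in figure 7 of the TTT paper), broadcast per batch and head, rather than with an empty dict. A rough sketch of that initialisation, assuming weights is a plain dict of parameter tensors (the library applies the repeat only when the weights still have their unexpanded init shape):

    from einops import repeat

    def init_past_state(weights, batch, heads):
        # the momentum slot still starts out empty
        empty = {name: None for name in weights}

        # broadcast W0 to one copy per (batch * heads), flattening the parameter dims,
        # using the same repeat pattern as the diff above
        w0 = {
            name: repeat(t, '... -> b 1 (...)', b = batch * heads)
            for name, t in weights.items()
        }

        return (w0, empty)   # (past last update, past last momentum)
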
@@ -734,7 +742,7 @@ class NeuralMemory(Module):

  # retrieve

- retrieved = self.retrieve_memories(token, updates + weights, chunk_size = 1)
+ retrieved = self.retrieve_memories(token, weights, chunk_size = 1)

  # next state tuple

@@ -796,7 +804,7 @@ class NeuralMemory(Module):

  retrieved = self.retrieve_memories(
  seq,
- mem_model_weights + updates,
+ mem_model_weights,
  chunk_size = chunk_size,
  prev_layer_updates = prev_layer_updates
  )