titans-pytorch 0.3.23__tar.gz → 0.3.24__tar.gz

This diff shows the changes between two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only.
PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.3.23
+Version: 0.3.24
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch

pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "titans-pytorch"
-version = "0.3.23"
+version = "0.3.24"
 description = "Titans"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }

tests/test_titans.py
@@ -29,6 +29,7 @@ def torch_default_dtype(dtype):
 @pytest.mark.parametrize('chunk_size, attn_pool_chunks', ((64, True), (64, False), (1, False)))
 @pytest.mark.parametrize('momentum', (False, True))
 @pytest.mark.parametrize('qk_rmsnorm', (False, True))
+@pytest.mark.parametrize('heads', (1, 4))
 @pytest.mark.parametrize('max_grad_norm', (None, 2.))
 @pytest.mark.parametrize('num_kv_per_token', (1, 2))
 @pytest.mark.parametrize('per_parameter_lr_modulation', (False, True))

@@ -40,6 +41,7 @@ def test_titans(
     chunk_size,
     momentum,
     qk_rmsnorm,
+    heads,
     max_grad_norm,
     num_kv_per_token,
     per_parameter_lr_modulation,

@@ -54,6 +56,7 @@ def test_titans(
         num_kv_per_token = num_kv_per_token,
         momentum = momentum,
         qk_rmsnorm = qk_rmsnorm,
+        heads = heads,
         per_parameter_lr_modulation = per_parameter_lr_modulation,
         per_head_learned_parameters = per_head_learned_parameters
     )
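
The new 'heads' axis is threaded through the existing test matrix: stacked pytest.mark.parametrize decorators take the cross-product of their value sets, so adding ('heads', (1, 4)) doubles the number of generated cases. A minimal sketch of that mechanism (toy test, not part of the package):

    import pytest

    @pytest.mark.parametrize('heads', (1, 4))
    @pytest.mark.parametrize('momentum', (False, True))
    def test_matrix(heads, momentum):
        # pytest generates 2 x 2 = 4 cases, one per combination of parameters
        assert heads in (1, 4)
        assert momentum in (False, True)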

titans_pytorch/neural_memory.py
@@ -289,6 +289,8 @@ class NeuralMemory(Module):
         self.heads = heads
 
         self.split_heads = Rearrange('b n (h d) -> b h n d', h = heads)
+        self.split_kv_heads = Rearrange('b n (h u d) -> b h (n u) d', h = heads, u = num_kv_per_token)
+
         self.merge_heads = Rearrange('b h n d -> b n (h d)')
         self.combine_heads = LinearNoBias(dim_inner, dim) if heads > 1 else nn.Identity()
 
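
The new split_kv_heads layer both separates the heads and folds the per-token key/value multiplicity u into the sequence axis in a single rearrange, whereas the old split_heads left u fused with the feature dimension. A quick shape check (toy sizes, not from the library):

    import torch
    from einops.layers.torch import Rearrange

    heads, num_kv_per_token, dim_head = 4, 2, 16
    split_kv_heads = Rearrange('b n (h u d) -> b h (n u) d', h = heads, u = num_kv_per_token)

    keys = torch.randn(1, 8, heads * num_kv_per_token * dim_head)  # (b, n, h * u * d)
    print(split_kv_heads(keys).shape)  # torch.Size([1, 4, 16, 16]): u is folded into the sequence axis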

@@ -596,17 +598,15 @@
 
         # maybe multi head
 
-        keys, values = map(self.split_heads, (keys, values))
-
-        batch = keys.shape[0]
+        keys, values = map(self.split_kv_heads, (keys, values))
 
-        # take care of chunking
+        # maybe keys rmsnorm
 
-        keys, values = tuple(rearrange(t, 'b h (n c) (u d) -> (b h n) (c u) d', c = chunk_size, u = num_updates) for t in (keys, values))
+        keys = self.k_norm(keys)
 
-        # maybe qk rmsnorm
+        # take care of chunking
 
-        keys = self.k_norm(keys)
+        keys, values = tuple(rearrange(t, 'b h (n c u) d -> (b h n) (c u) d', c = chunk_size, u = num_updates) for t in (keys, values))
 
         # adaptive lr
 
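
Two things change in this store path. First, since split_kv_heads has already folded u into the sequence axis, the chunking rearrange simplifies from 'b h (n c) (u d) -> (b h n) (c u) d' to 'b h (n c u) d -> (b h n) (c u) d', and the now-unused batch variable is dropped. Second, the key RMSNorm moves ahead of the chunking step; at both the old and the new call site the trailing axis is the per-head key dimension d, so, assuming k_norm acts over the last dimension, this reorders the steps rather than changing which axis is normalized. A toy shape check of the new chunking (hypothetical sizes):

    import torch
    from einops import rearrange

    b, h, d = 1, 4, 16
    n_chunks, chunk_size, num_updates = 3, 64, 2
    keys = torch.randn(b, h, n_chunks * chunk_size * num_updates, d)  # u already folded into the sequence

    chunked = rearrange(keys, 'b h (n c u) d -> (b h n) (c u) d', c = chunk_size, u = num_updates)
    print(chunked.shape)  # torch.Size([12, 128, 16]): one row per (batch, head, chunk), c * u tokens each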