titans-pytorch 0.3.22__py3-none-any.whl → 0.3.24__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -289,6 +289,8 @@ class NeuralMemory(Module):
289
289
  self.heads = heads
290
290
 
291
291
  self.split_heads = Rearrange('b n (h d) -> b h n d', h = heads)
292
+ self.split_kv_heads = Rearrange('b n (h u d) -> b h (n u) d', h = heads, u = num_kv_per_token)
293
+
292
294
  self.merge_heads = Rearrange('b h n d -> b n (h d)')
293
295
  self.combine_heads = LinearNoBias(dim_inner, dim) if heads > 1 else nn.Identity()
294
296
 
@@ -425,8 +427,8 @@ class NeuralMemory(Module):
425
427
 
426
428
  self.to_learned_momentum_combine = Sequential(
427
429
  nn.Linear(dim, heads * momentum_order),
428
- nn.Softmax(dim = -1),
429
- Rearrange('b n (h o) -> o (b h) n', h = heads)
430
+ Rearrange('b n (h o) -> o (b h) n', h = heads),
431
+ nn.Softmax(dim = 0),
430
432
  )
431
433
 
432
434
  self.learned_combine_include_zeroth = learned_combine_include_zeroth
@@ -596,17 +598,15 @@ class NeuralMemory(Module):
596
598
 
597
599
  # maybe multi head
598
600
 
599
- keys, values = map(self.split_heads, (keys, values))
600
-
601
- batch = keys.shape[0]
601
+ keys, values = map(self.split_kv_heads, (keys, values))
602
602
 
603
- # take care of chunking
603
+ # maybe keys rmsnorm
604
604
 
605
- keys, values = tuple(rearrange(t, 'b h (n c) (u d) -> (b h n) (c u) d', c = chunk_size, u = num_updates) for t in (keys, values))
605
+ keys = self.k_norm(keys)
606
606
 
607
- # maybe qk rmsnorm
607
+ # take care of chunking
608
608
 
609
- keys = self.k_norm(keys)
609
+ keys, values = tuple(rearrange(t, 'b h (n c u) d -> (b h n) (c u) d', c = chunk_size, u = num_updates) for t in (keys, values))
610
610
 
611
611
  # adaptive lr
612
612
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: titans-pytorch
3
- Version: 0.3.22
3
+ Version: 0.3.24
4
4
  Summary: Titans
5
5
  Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
6
6
  Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -2,8 +2,8 @@ titans_pytorch/__init__.py,sha256=AyEUlcXWpnqrvyeihRAXWIfQlzLA4NhBjOqQU4edL-4,29
2
2
  titans_pytorch/associative_scan.py,sha256=esaLbukFlgvy2aqopsqBy6KEcZ64B3rsNhG8moKdPSc,5159
3
3
  titans_pytorch/mac_transformer.py,sha256=grD327B3OCIy7d23jNUWIoUo1bIgXUqD26dXWCjdi28,25565
4
4
  titans_pytorch/memory_models.py,sha256=wnH9i9kUSoVZhEWUlj8LpBSbB400L9kLt1zP8CO45QQ,5835
5
- titans_pytorch/neural_memory.py,sha256=2NPaesGpSrYoGnp4VshYPm1gvA1W7xcuJzZ2nydCcGY,31161
6
- titans_pytorch-0.3.22.dist-info/METADATA,sha256=kOPNtBdpt8QQMENoPhjE1wqN-zQLrt40QCqtMTatRhQ,6817
7
- titans_pytorch-0.3.22.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
8
- titans_pytorch-0.3.22.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
9
- titans_pytorch-0.3.22.dist-info/RECORD,,
5
+ titans_pytorch/neural_memory.py,sha256=haSepQdfGsQdoo9Yk5agvRR91kTu8kgkXpBmBZaH8WI,31237
6
+ titans_pytorch-0.3.24.dist-info/METADATA,sha256=0-WHTKNXZpESWfOMSOO8MiWddqDoSRP1lifsfgHmewo,6817
7
+ titans_pytorch-0.3.24.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
8
+ titans_pytorch-0.3.24.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
9
+ titans_pytorch-0.3.24.dist-info/RECORD,,