titans-pytorch 0.0.23 → 0.0.25 (py3-none-any.whl)

titans_pytorch/titans.py CHANGED
@@ -73,6 +73,17 @@ def softclamp_grad_norm(t, max_value):
     t = t * (clamped_norm / norm)
     return inverse(t)
 
+# multi head rmsnorm
+
+class MultiheadRMSNorm(Module):
+    def __init__(self, dim, heads):
+        super().__init__()
+        self.rmsnorm = nn.RMSNorm(dim, elementwise_affine = False)
+        self.gamma = nn.Parameter(torch.zeros(heads, 1, dim))
+
+    def forward(self, x):
+        return self.rmsnorm(x) * (self.gamma + 1.)
+
 # classes
 
 class MemoryMLP(Module):
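For context, a minimal runnable sketch of the new module, assuming torch >= 2.4 (where `nn.RMSNorm` was added). The per-head `gamma` is zero-initialized, so `(gamma + 1.)` starts at 1 and the module reduces to a plain RMSNorm over each head's feature dimension at initialization:

```python
import torch
from torch import nn

class MultiheadRMSNorm(nn.Module):
    def __init__(self, dim, heads):
        super().__init__()
        # normalize each head's features, without a learned affine ...
        self.rmsnorm = nn.RMSNorm(dim, elementwise_affine = False)
        # ... then apply a per-head gain, zero-initialized so the scale starts at 1
        self.gamma = nn.Parameter(torch.zeros(heads, 1, dim))

    def forward(self, x):
        # x: (batch, heads, seq, dim_head); gamma broadcasts over batch and seq
        return self.rmsnorm(x) * (self.gamma + 1.)

norm = MultiheadRMSNorm(dim = 64, heads = 4)
x = torch.randn(2, 4, 128, 64)
assert norm(x).shape == x.shape
```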
@@ -125,26 +136,28 @@ class NeuralMemory(Module):
         )
     ):
         super().__init__()
+        dim_head = default(dim_head, dim)
 
         # norms
 
         self.retrieve_norm = nn.RMSNorm(dim) if pre_rmsnorm else nn.Identity()
         self.store_norm = nn.RMSNorm(dim) if pre_rmsnorm else nn.Identity()
 
-        self.post_rmsnorm = nn.RMSNorm(dim) if post_rmsnorm else nn.Identity()
+        self.multihead_rmsnorm = MultiheadRMSNorm(dim_head, heads) if post_rmsnorm else nn.Identity()
 
         # maybe multi-headed
 
-        dim_head = default(dim_head, dim)
         dim_inner = dim_head * heads
 
+        self.heads = heads
+
         self.split_heads = Rearrange('b n (h d) -> (b h) n d', h = heads)
-        self.merge_heads = Rearrange('(b h) n d -> b n (h d)', h = heads)
+        self.merge_heads = Rearrange('b h n d -> b n (h d)')
         self.combine_heads = LinearNoBias(dim_inner, dim) if heads > 1 else nn.Identity()
 
         self.retrieve_gate = nn.Sequential(
             LinearNoBias(dim, heads),
-            Rearrange('b n h -> (b h) n 1'),
+            Rearrange('b n h -> b h n 1'),
             nn.Sigmoid()
         ) if heads > 1 else None
 
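This hunk changes the head bookkeeping: retrieved values now keep heads on a dedicated axis, so `merge_heads` no longer needs the `h` keyword, and the retrieve gate reshapes to `b h n 1` so it broadcasts directly against `(b, h, n, d)` values. A hedged illustration with einops (the toy sizes are mine, not from the source):

```python
import torch
from einops import rearrange

b, h, n, d = 2, 4, 128, 64

values = torch.randn(b, h, n, d)   # values with heads on their own axis
gate = torch.rand(b, h, n, 1)      # after Rearrange('b n h -> b h n 1') + sigmoid

gated = values * gate                              # (b, h, n, 1) broadcasts over d
merged = rearrange(gated, 'b h n d -> b n (h d)')  # the new merge_heads pattern
assert merged.shape == (b, n, h * d)
```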
@@ -364,7 +377,7 @@ class NeuralMemory(Module):
         past_weights: dict[str, Tensor] | None = None,
     ):
         chunk_size = self.chunk_size
-        seq_len = seq.shape[1]
+        batch, seq_len = seq.shape[:2]
 
         seq = self.retrieve_norm(seq)
 
@@ -398,8 +411,6 @@ class NeuralMemory(Module):
 
         queries = self.split_heads(queries)
 
-        batch = queries.shape[0]
-
         # fetch values from memory model
 
         curr_weights = curr_weights.apply(lambda t: rearrange(t, 'b n ... -> (b n) ...'))
@@ -411,7 +422,9 @@ class NeuralMemory(Module):
 
         # reconstitute batch dimension
 
-        values = rearrange(values, '(b n) c d -> b (n c) d', b = batch)
+        values = rearrange(values, '(b h n) c d -> b h (n c) d', b = batch, h = self.heads)
+
+        values = self.multihead_rmsnorm(values)
 
         # maybe gate
 
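The reconstitution now unpacks heads as well as batch from the flattened leading axis before the per-head norm runs. A quick shape check under assumed toy sizes:

```python
import torch
from einops import rearrange

b, h, n_chunks, c, d = 2, 4, 8, 16, 64

# memory model output, with batch, heads and chunk count flattened together
values = torch.randn(b * h * n_chunks, c, d)

# new pattern: recover batch and heads, fuse chunks back into the sequence
values = rearrange(values, '(b h n) c d -> b h (n c) d', b = b, h = h)
assert values.shape == (b, h, n_chunks * c, d)
```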
@@ -424,10 +437,6 @@ class NeuralMemory(Module):
 
         values = self.combine_heads(values)
 
-        # post norm, somehow could not stabilize this without it, not in paper
-
-        values = self.post_rmsnorm(values)
-
         # restore, pad with empty memory embed
 
         empty_memory_embeds = self.init_empty_memory_embed(values.shape[0], chunk_size - 1)
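Net effect of the last three hunks: the trailing `post_rmsnorm` after `combine_heads` is gone, replaced by a norm applied per head, before gating and merging. A sketch of the new ordering, with hypothetical stand-ins for the module attributes (the learned per-head gamma is omitted for brevity):

```python
import torch
from torch import nn
from einops.layers.torch import Rearrange

b, h, n, d, dim = 2, 4, 128, 64, 256

multihead_rmsnorm = nn.RMSNorm(d, elementwise_affine = False)  # stand-in for the new module
merge_heads = Rearrange('b h n d -> b n (h d)')
combine_heads = nn.Linear(h * d, dim, bias = False)

values = torch.randn(b, h, n, d)
values = multihead_rmsnorm(values)  # norm now happens per head, before merging
values = merge_heads(values)
out = combine_heads(values)         # no post_rmsnorm afterwards
assert out.shape == (b, n, dim)
```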
{titans_pytorch-0.0.23.dist-info → titans_pytorch-0.0.25.dist-info}/METADATA RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.0.23
+Version: 0.0.25
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
titans_pytorch-0.0.25.dist-info/RECORD ADDED
@@ -0,0 +1,8 @@
+titans_pytorch/__init__.py,sha256=nB0873FZ_OyCda3qFeWTdpO4LbbrXDEvtAefVLzh6g0,71
+titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
+titans_pytorch/titans.py,sha256=Kx_tl_QkeDccvkMwPZ0xQ_saYjZfbKzDNPTTSHNWYcA,14304
+titans_pytorch/titans_attn_memory.py,sha256=Rwx_-riGeISBefZg5Kjic8jzmmRRys-u93D2Kgb7Mos,12691
+titans_pytorch-0.0.25.dist-info/METADATA,sha256=NeRAldZl9fN7bc3YzL44kDzH2rsq5SMBZ-7RjGs_B0g,3811
+titans_pytorch-0.0.25.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+titans_pytorch-0.0.25.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+titans_pytorch-0.0.25.dist-info/RECORD,,
titans_pytorch-0.0.23.dist-info/RECORD DELETED
@@ -1,8 +0,0 @@
-titans_pytorch/__init__.py,sha256=nB0873FZ_OyCda3qFeWTdpO4LbbrXDEvtAefVLzh6g0,71
-titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
-titans_pytorch/titans.py,sha256=AU2mf3RkClSRIP0IUwnnqsA5O1udNYGbTRb0lVBLA78,14024
-titans_pytorch/titans_attn_memory.py,sha256=Rwx_-riGeISBefZg5Kjic8jzmmRRys-u93D2Kgb7Mos,12691
-titans_pytorch-0.0.23.dist-info/METADATA,sha256=U80_8U_mwaQqwWKPWlu76-O3-CTCfrZ7t_HdE3Zl_qE,3811
-titans_pytorch-0.0.23.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-titans_pytorch-0.0.23.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-titans_pytorch-0.0.23.dist-info/RECORD,,