titans-pytorch 0.0.23__py3-none-any.whl → 0.0.24__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
titans_pytorch/titans.py CHANGED
@@ -73,6 +73,17 @@ def softclamp_grad_norm(t, max_value):
73
73
  t = t * (clamped_norm / norm)
74
74
  return inverse(t)
75
75
 
76
+ # multi head rmsnorm
77
+
78
+ class MultiheadRMSNorm(Module):
79
+ def __init__(self, dim, heads):
80
+ super().__init__()
81
+ self.rmsnorm = nn.RMSNorm(dim, elementwise_affine = False)
82
+ self.gamma = nn.Parameter(torch.zeros(heads, 1, dim))
83
+
84
+ def forward(self, x):
85
+ return self.rmsnorm(x) * (self.gamma + 1.)
86
+
76
87
  # classes
77
88
 
78
89
  class MemoryMLP(Module):
@@ -131,20 +142,22 @@ class NeuralMemory(Module):
131
142
  self.retrieve_norm = nn.RMSNorm(dim) if pre_rmsnorm else nn.Identity()
132
143
  self.store_norm = nn.RMSNorm(dim) if pre_rmsnorm else nn.Identity()
133
144
 
134
- self.post_rmsnorm = nn.RMSNorm(dim) if post_rmsnorm else nn.Identity()
145
+ self.multihead_rmsnorm = MultiheadRMSNorm(dim_head, heads) if post_rmsnorm else nn.Identity()
135
146
 
136
147
  # maybe multi-headed
137
148
 
138
149
  dim_head = default(dim_head, dim)
139
150
  dim_inner = dim_head * heads
140
151
 
152
+ self.heads = heads
153
+
141
154
  self.split_heads = Rearrange('b n (h d) -> (b h) n d', h = heads)
142
- self.merge_heads = Rearrange('(b h) n d -> b n (h d)', h = heads)
155
+ self.merge_heads = Rearrange('b h n d -> b n (h d)')
143
156
  self.combine_heads = LinearNoBias(dim_inner, dim) if heads > 1 else nn.Identity()
144
157
 
145
158
  self.retrieve_gate = nn.Sequential(
146
159
  LinearNoBias(dim, heads),
147
- Rearrange('b n h -> (b h) n 1'),
160
+ Rearrange('b n h -> b h n 1'),
148
161
  nn.Sigmoid()
149
162
  ) if heads > 1 else None
150
163
 
@@ -364,7 +377,7 @@ class NeuralMemory(Module):
364
377
  past_weights: dict[str, Tensor] | None = None,
365
378
  ):
366
379
  chunk_size = self.chunk_size
367
- seq_len = seq.shape[1]
380
+ batch, seq_len = seq.shape[:2]
368
381
 
369
382
  seq = self.retrieve_norm(seq)
370
383
 
@@ -398,8 +411,6 @@ class NeuralMemory(Module):
398
411
 
399
412
  queries = self.split_heads(queries)
400
413
 
401
- batch = queries.shape[0]
402
-
403
414
  # fetch values from memory model
404
415
 
405
416
  curr_weights = curr_weights.apply(lambda t: rearrange(t, 'b n ... -> (b n) ...'))
@@ -411,7 +422,9 @@ class NeuralMemory(Module):
411
422
 
412
423
  # reconstitute batch dimension
413
424
 
414
- values = rearrange(values, '(b n) c d -> b (n c) d', b = batch)
425
+ values = rearrange(values, '(b h n) c d -> b h (n c) d', b = batch, h = self.heads)
426
+
427
+ values = self.multihead_rmsnorm(values)
415
428
 
416
429
  # maybe gate
417
430
 
@@ -424,10 +437,6 @@ class NeuralMemory(Module):
424
437
 
425
438
  values = self.combine_heads(values)
426
439
 
427
- # post norm, somehow could not stabilize this without it, not in paper
428
-
429
- values = self.post_rmsnorm(values)
430
-
431
440
  # restore, pad with empty memory embed
432
441
 
433
442
  empty_memory_embeds = self.init_empty_memory_embed(values.shape[0], chunk_size - 1)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: titans-pytorch
3
- Version: 0.0.23
3
+ Version: 0.0.24
4
4
  Summary: Titans
5
5
  Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
6
6
  Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -0,0 +1,8 @@
1
+ titans_pytorch/__init__.py,sha256=nB0873FZ_OyCda3qFeWTdpO4LbbrXDEvtAefVLzh6g0,71
2
+ titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
3
+ titans_pytorch/titans.py,sha256=3UxRZl_uwQBly11jQAWjfnNzHSoOUKiw-Ux2lXu2ilI,14304
4
+ titans_pytorch/titans_attn_memory.py,sha256=Rwx_-riGeISBefZg5Kjic8jzmmRRys-u93D2Kgb7Mos,12691
5
+ titans_pytorch-0.0.24.dist-info/METADATA,sha256=WGxo4oVx9HCq7LvSH8u_isp1tjxVXb3Ao_GrgjdFzSo,3811
6
+ titans_pytorch-0.0.24.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
7
+ titans_pytorch-0.0.24.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
8
+ titans_pytorch-0.0.24.dist-info/RECORD,,
@@ -1,8 +0,0 @@
1
- titans_pytorch/__init__.py,sha256=nB0873FZ_OyCda3qFeWTdpO4LbbrXDEvtAefVLzh6g0,71
2
- titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
3
- titans_pytorch/titans.py,sha256=AU2mf3RkClSRIP0IUwnnqsA5O1udNYGbTRb0lVBLA78,14024
4
- titans_pytorch/titans_attn_memory.py,sha256=Rwx_-riGeISBefZg5Kjic8jzmmRRys-u93D2Kgb7Mos,12691
5
- titans_pytorch-0.0.23.dist-info/METADATA,sha256=U80_8U_mwaQqwWKPWlu76-O3-CTCfrZ7t_HdE3Zl_qE,3811
6
- titans_pytorch-0.0.23.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
7
- titans_pytorch-0.0.23.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
8
- titans_pytorch-0.0.23.dist-info/RECORD,,