titans-pytorch 0.0.23__tar.gz → 0.0.24__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: titans-pytorch
3
- Version: 0.0.23
3
+ Version: 0.0.24
4
4
  Summary: Titans
5
5
  Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
6
6
  Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "titans-pytorch"
3
- version = "0.0.23"
3
+ version = "0.0.24"
4
4
  description = "Titans"
5
5
  authors = [
6
6
  { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -73,6 +73,17 @@ def softclamp_grad_norm(t, max_value):
73
73
  t = t * (clamped_norm / norm)
74
74
  return inverse(t)
75
75
 
76
+ # multi head rmsnorm
77
+
78
+ class MultiheadRMSNorm(Module):
79
+ def __init__(self, dim, heads):
80
+ super().__init__()
81
+ self.rmsnorm = nn.RMSNorm(dim, elementwise_affine = False)
82
+ self.gamma = nn.Parameter(torch.zeros(heads, 1, dim))
83
+
84
+ def forward(self, x):
85
+ return self.rmsnorm(x) * (self.gamma + 1.)
86
+
76
87
  # classes
77
88
 
78
89
  class MemoryMLP(Module):
@@ -131,20 +142,22 @@ class NeuralMemory(Module):
131
142
  self.retrieve_norm = nn.RMSNorm(dim) if pre_rmsnorm else nn.Identity()
132
143
  self.store_norm = nn.RMSNorm(dim) if pre_rmsnorm else nn.Identity()
133
144
 
134
- self.post_rmsnorm = nn.RMSNorm(dim) if post_rmsnorm else nn.Identity()
145
+ self.multihead_rmsnorm = MultiheadRMSNorm(dim_head, heads) if post_rmsnorm else nn.Identity()
135
146
 
136
147
  # maybe multi-headed
137
148
 
138
149
  dim_head = default(dim_head, dim)
139
150
  dim_inner = dim_head * heads
140
151
 
152
+ self.heads = heads
153
+
141
154
  self.split_heads = Rearrange('b n (h d) -> (b h) n d', h = heads)
142
- self.merge_heads = Rearrange('(b h) n d -> b n (h d)', h = heads)
155
+ self.merge_heads = Rearrange('b h n d -> b n (h d)')
143
156
  self.combine_heads = LinearNoBias(dim_inner, dim) if heads > 1 else nn.Identity()
144
157
 
145
158
  self.retrieve_gate = nn.Sequential(
146
159
  LinearNoBias(dim, heads),
147
- Rearrange('b n h -> (b h) n 1'),
160
+ Rearrange('b n h -> b h n 1'),
148
161
  nn.Sigmoid()
149
162
  ) if heads > 1 else None
150
163
 
@@ -364,7 +377,7 @@ class NeuralMemory(Module):
364
377
  past_weights: dict[str, Tensor] | None = None,
365
378
  ):
366
379
  chunk_size = self.chunk_size
367
- seq_len = seq.shape[1]
380
+ batch, seq_len = seq.shape[:2]
368
381
 
369
382
  seq = self.retrieve_norm(seq)
370
383
 
@@ -398,8 +411,6 @@ class NeuralMemory(Module):
398
411
 
399
412
  queries = self.split_heads(queries)
400
413
 
401
- batch = queries.shape[0]
402
-
403
414
  # fetch values from memory model
404
415
 
405
416
  curr_weights = curr_weights.apply(lambda t: rearrange(t, 'b n ... -> (b n) ...'))
@@ -411,7 +422,9 @@ class NeuralMemory(Module):
411
422
 
412
423
  # reconstitute batch dimension
413
424
 
414
- values = rearrange(values, '(b n) c d -> b (n c) d', b = batch)
425
+ values = rearrange(values, '(b h n) c d -> b h (n c) d', b = batch, h = self.heads)
426
+
427
+ values = self.multihead_rmsnorm(values)
415
428
 
416
429
  # maybe gate
417
430
 
@@ -424,10 +437,6 @@ class NeuralMemory(Module):
424
437
 
425
438
  values = self.combine_heads(values)
426
439
 
427
- # post norm, somehow could not stabilize this without it, not in paper
428
-
429
- values = self.post_rmsnorm(values)
430
-
431
440
  # restore, pad with empty memory embed
432
441
 
433
442
  empty_memory_embeds = self.init_empty_memory_embed(values.shape[0], chunk_size - 1)
@@ -32,7 +32,7 @@ SEQ_LEN = 512
32
32
 
33
33
  PROJECT_NAME = 'titans-neural-memory'
34
34
  WANDB_ONLINE = False # turn this on to pipe experiment to cloud
35
- GLOBAL_LAYERS = (4, 5)
35
+ GLOBAL_LAYERS = (2, 4)
36
36
  USE_TITANS_MEMORY = True
37
37
  NEURAL_MEMORY_DEPTH = 2
38
38
  WINDOW_SIZE = 64
File without changes