titans-pytorch 0.0.23__tar.gz → 0.0.24__tar.gz

Sign up to get free protection for your applications and to get access to all the features.

Potentially problematic release.


This version of titans-pytorch might be problematic. Click here for more details.

@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: titans-pytorch
3
- Version: 0.0.23
3
+ Version: 0.0.24
4
4
  Summary: Titans
5
5
  Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
6
6
  Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "titans-pytorch"
3
- version = "0.0.23"
3
+ version = "0.0.24"
4
4
  description = "Titans"
5
5
  authors = [
6
6
  { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -73,6 +73,17 @@ def softclamp_grad_norm(t, max_value):
73
73
  t = t * (clamped_norm / norm)
74
74
  return inverse(t)
75
75
 
76
+ # multi head rmsnorm
77
+
78
+ class MultiheadRMSNorm(Module):
79
+ def __init__(self, dim, heads):
80
+ super().__init__()
81
+ self.rmsnorm = nn.RMSNorm(dim, elementwise_affine = False)
82
+ self.gamma = nn.Parameter(torch.zeros(heads, 1, dim))
83
+
84
+ def forward(self, x):
85
+ return self.rmsnorm(x) * (self.gamma + 1.)
86
+
76
87
  # classes
77
88
 
78
89
  class MemoryMLP(Module):
@@ -131,20 +142,22 @@ class NeuralMemory(Module):
131
142
  self.retrieve_norm = nn.RMSNorm(dim) if pre_rmsnorm else nn.Identity()
132
143
  self.store_norm = nn.RMSNorm(dim) if pre_rmsnorm else nn.Identity()
133
144
 
134
- self.post_rmsnorm = nn.RMSNorm(dim) if post_rmsnorm else nn.Identity()
145
+ self.multihead_rmsnorm = MultiheadRMSNorm(dim_head, heads) if post_rmsnorm else nn.Identity()
135
146
 
136
147
  # maybe multi-headed
137
148
 
138
149
  dim_head = default(dim_head, dim)
139
150
  dim_inner = dim_head * heads
140
151
 
152
+ self.heads = heads
153
+
141
154
  self.split_heads = Rearrange('b n (h d) -> (b h) n d', h = heads)
142
- self.merge_heads = Rearrange('(b h) n d -> b n (h d)', h = heads)
155
+ self.merge_heads = Rearrange('b h n d -> b n (h d)')
143
156
  self.combine_heads = LinearNoBias(dim_inner, dim) if heads > 1 else nn.Identity()
144
157
 
145
158
  self.retrieve_gate = nn.Sequential(
146
159
  LinearNoBias(dim, heads),
147
- Rearrange('b n h -> (b h) n 1'),
160
+ Rearrange('b n h -> b h n 1'),
148
161
  nn.Sigmoid()
149
162
  ) if heads > 1 else None
150
163
 
@@ -364,7 +377,7 @@ class NeuralMemory(Module):
364
377
  past_weights: dict[str, Tensor] | None = None,
365
378
  ):
366
379
  chunk_size = self.chunk_size
367
- seq_len = seq.shape[1]
380
+ batch, seq_len = seq.shape[:2]
368
381
 
369
382
  seq = self.retrieve_norm(seq)
370
383
 
@@ -398,8 +411,6 @@ class NeuralMemory(Module):
398
411
 
399
412
  queries = self.split_heads(queries)
400
413
 
401
- batch = queries.shape[0]
402
-
403
414
  # fetch values from memory model
404
415
 
405
416
  curr_weights = curr_weights.apply(lambda t: rearrange(t, 'b n ... -> (b n) ...'))
@@ -411,7 +422,9 @@ class NeuralMemory(Module):
411
422
 
412
423
  # reconstitute batch dimension
413
424
 
414
- values = rearrange(values, '(b n) c d -> b (n c) d', b = batch)
425
+ values = rearrange(values, '(b h n) c d -> b h (n c) d', b = batch, h = self.heads)
426
+
427
+ values = self.multihead_rmsnorm(values)
415
428
 
416
429
  # maybe gate
417
430
 
@@ -424,10 +437,6 @@ class NeuralMemory(Module):
424
437
 
425
438
  values = self.combine_heads(values)
426
439
 
427
- # post norm, somehow could not stabilize this without it, not in paper
428
-
429
- values = self.post_rmsnorm(values)
430
-
431
440
  # restore, pad with empty memory embed
432
441
 
433
442
  empty_memory_embeds = self.init_empty_memory_embed(values.shape[0], chunk_size - 1)
@@ -32,7 +32,7 @@ SEQ_LEN = 512
32
32
 
33
33
  PROJECT_NAME = 'titans-neural-memory'
34
34
  WANDB_ONLINE = False # turn this on to pipe experiment to cloud
35
- GLOBAL_LAYERS = (4, 5)
35
+ GLOBAL_LAYERS = (2, 4)
36
36
  USE_TITANS_MEMORY = True
37
37
  NEURAL_MEMORY_DEPTH = 2
38
38
  WINDOW_SIZE = 64
File without changes