titans-pytorch 0.0.18__tar.gz → 0.0.19__tar.gz

Sign up to get free protection for your applications and to get access to all the features.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: titans-pytorch
3
- Version: 0.0.18
3
+ Version: 0.0.19
4
4
  Summary: Titans
5
5
  Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
6
6
  Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "titans-pytorch"
3
- version = "0.0.18"
3
+ version = "0.0.19"
4
4
  description = "Titans"
5
5
  authors = [
6
6
  { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -1,7 +1,11 @@
1
1
  import torch
2
2
  import pytest
3
3
 
4
- def test_titans():
4
+ @pytest.mark.parametrize('seq_len', (32, 1024, 77))
5
+ def test_titans(
6
+ seq_len
7
+ ):
8
+
5
9
  from titans_pytorch import NeuralMemory
6
10
 
7
11
  mem = NeuralMemory(
@@ -9,7 +13,7 @@ def test_titans():
9
13
  chunk_size = 64,
10
14
  )
11
15
 
12
- seq = torch.randn(2, 1024, 384)
16
+ seq = torch.randn(2, seq_len, 384)
13
17
  retrieved = mem(seq)
14
18
 
15
19
  assert seq.shape == retrieved.shape
@@ -17,7 +17,7 @@ from titans_pytorch.associative_scan import (
17
17
  )
18
18
 
19
19
  import einx
20
- from einops import rearrange, pack, unpack
20
+ from einops import rearrange, repeat, pack, unpack
21
21
  from einops.layers.torch import Rearrange, Reduce
22
22
 
23
23
  """
@@ -152,6 +152,11 @@ class NeuralMemory(Module):
152
152
  self.to_keys_values = LinearNoBias(dim, dim_inner * 2)
153
153
  self.store_memory_loss_fn = store_memory_loss_fn
154
154
 
155
+ # empty memory embed
156
+
157
+ self.empty_memory_embed = nn.Parameter(torch.zeros(dim))
158
+ nn.init.normal_(self.empty_memory_embed, std = 0.02)
159
+
155
160
  # learned adaptive learning rate and momentum
156
161
  # todo - explore mlp layerwise learned lr / momentum
157
162
 
@@ -187,6 +192,9 @@ class NeuralMemory(Module):
187
192
 
188
193
  return init_weights, init_momentum
189
194
 
195
+ def init_empty_memory_embed(self, batch, seq_len):
196
+ return repeat(self.empty_memory_embed, 'd -> b n d', b = batch, n = seq_len)
197
+
190
198
  def store_memories(
191
199
  self,
192
200
  seq,
@@ -372,11 +380,12 @@ class NeuralMemory(Module):
372
380
 
373
381
  values = self.post_rmsnorm(values)
374
382
 
375
- # restore
383
+ # restore, pad with empty memory embed
376
384
 
377
- values = pad_at_dim(values, (chunk_size - 1, 0), dim = 1, value = 0.) # todo, used a learned null memory embedding instead of 0s for retrieving from empty neural memory
378
- values = values[:, :-padding]
385
+ empty_memory_embeds = self.init_empty_memory_embed(values.shape[0], chunk_size - 1)
386
+ values = torch.cat((empty_memory_embeds, values), dim = -2)
379
387
 
388
+ values = values[:, :-padding]
380
389
  return values
381
390
 
382
391
  def forward(
@@ -389,7 +398,7 @@ class NeuralMemory(Module):
389
398
  batch, seq_len = seq.shape[:2]
390
399
 
391
400
  if seq_len < self.chunk_size:
392
- return torch.zeros_like(seq)
401
+ return self.init_empty_memory_embed(batch, seq_len)
393
402
 
394
403
  if exists(past_state):
395
404
  past_state = tuple(TensorDict(d) for d in past_state)
File without changes