titans-pytorch 0.0.17__tar.gz → 0.0.19__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.0.17
+Version: 0.0.19
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch

@@ -1,6 +1,6 @@
 [project]
 name = "titans-pytorch"
-version = "0.0.17"
+version = "0.0.19"
 description = "Titans"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }

@@ -1,7 +1,11 @@
 import torch
 import pytest

-def test_titans():
+@pytest.mark.parametrize('seq_len', (32, 1024, 77))
+def test_titans(
+    seq_len
+):
+
     from titans_pytorch import NeuralMemory

     mem = NeuralMemory(
@@ -9,7 +13,7 @@ def test_titans():
         chunk_size = 64,
     )

-    seq = torch.randn(2, 1024, 384)
+    seq = torch.randn(2, seq_len, 384)
    retrieved = mem(seq)

    assert seq.shape == retrieved.shape

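The test is now parametrized over a sequence shorter than the chunk size (32), an exact multiple of it (1024), and a non-multiple (77). A minimal usage sketch of what the parametrized test exercises, assuming the `dim = 384` constructor argument that sits just outside the hunks above:

import torch
from titans_pytorch import NeuralMemory

# dim = 384 is assumed from the tensor shapes in the test; the actual
# constructor call is partly outside the hunks shown above
mem = NeuralMemory(
    dim = 384,
    chunk_size = 64
)

for seq_len in (32, 1024, 77):
    seq = torch.randn(2, seq_len, 384)
    retrieved = mem(seq)

    # retrieved memories come back with the same shape as the input;
    # the 32-token case is shorter than chunk_size and, in 0.0.19,
    # returns the learned empty-memory embedding instead of zeros
    assert retrieved.shape == seq.shape
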
@@ -17,7 +17,7 @@ from titans_pytorch.associative_scan import (
 )

 import einx
-from einops import rearrange, pack, unpack
+from einops import rearrange, repeat, pack, unpack
 from einops.layers.torch import Rearrange, Reduce

 """
@@ -152,6 +152,11 @@ class NeuralMemory(Module):
         self.to_keys_values = LinearNoBias(dim, dim_inner * 2)
         self.store_memory_loss_fn = store_memory_loss_fn

+        # empty memory embed
+
+        self.empty_memory_embed = nn.Parameter(torch.zeros(dim))
+        nn.init.normal_(self.empty_memory_embed, std = 0.02)
+
         # learned adaptive learning rate and momentum
         # todo - explore mlp layerwise learned lr / momentum

@@ -187,6 +192,9 @@ class NeuralMemory(Module):

         return init_weights, init_momentum

+    def init_empty_memory_embed(self, batch, seq_len):
+        return repeat(self.empty_memory_embed, 'd -> b n d', b = batch, n = seq_len)
+
     def store_memories(
         self,
         seq,
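The two hunks above add a learned "empty memory" embedding: a single vector of size `dim` that is broadcast to `(batch, seq_len, dim)` with the newly imported `einops.repeat` whenever there is no stored memory to retrieve from. A standalone sketch of the pattern, with illustrative shapes outside the actual module:

import torch
from torch import nn
from einops import repeat

dim = 384

# a single learned vector, initialized with a small normal like a token embedding
empty_memory_embed = nn.Parameter(torch.zeros(dim))
nn.init.normal_(empty_memory_embed, std = 0.02)

# broadcast the vector across batch and sequence dimensions
batch, seq_len = 2, 16
empty_memories = repeat(empty_memory_embed, 'd -> b n d', b = batch, n = seq_len)

print(empty_memories.shape)  # torch.Size([2, 16, 384])
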
@@ -269,7 +277,7 @@ class NeuralMemory(Module):
             gates = F.pad(gates, (0, next_power_two_seq_len - seq_len))
             inputs = F.pad(inputs, (0, next_power_two_seq_len - seq_len))

-            outputs = scan(gates, inputs)
+            outputs = scan(gates.contiguous(), inputs.contiguous())

             outputs = outputs[..., :seq_len]
             outputs = rearrange(outputs, 'b d n -> b n d')
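The `.contiguous()` calls are presumably defensive: upstream `rearrange`/`expand`-style operations can hand the scan strided, non-contiguous views, and accelerated scan kernels generally expect contiguous memory. A small illustration of the distinction:

import torch
from einops import rearrange

x = torch.randn(2, 8, 4)

# transposing via rearrange yields a strided view, not a contiguous tensor
y = rearrange(x, 'b n d -> b d n')
print(y.is_contiguous())                 # False

# .contiguous() materializes the layout the kernel expects;
# it returns the tensor unchanged if it is already contiguous
print(y.contiguous().is_contiguous())    # True
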
@@ -372,11 +380,12 @@ class NeuralMemory(Module):

         values = self.post_rmsnorm(values)

-        # restore
+        # restore, pad with empty memory embed

-        values = pad_at_dim(values, (chunk_size - 1, 0), dim = 1, value = 0.) # todo, used a learned null memory embedding instead of 0s for retrieving from empty neural memory
-        values = values[:, :-padding]
+        empty_memory_embeds = self.init_empty_memory_embed(values.shape[0], chunk_size - 1)
+        values = torch.cat((empty_memory_embeds, values), dim = -2)

+        values = values[:, :-padding]
         return values

     def forward(
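In the retrieval path, the `chunk_size - 1` leading positions that were previously zero-padded (per the removed todo comment) are now filled with the learned empty-memory embedding before the right-side chunk padding is trimmed. A self-contained sketch of the shape bookkeeping, using illustrative sizes in place of the real module state:

import torch
from einops import repeat

batch, chunk_size, dim = 2, 64, 384
padding = 3                              # right padding previously added to reach a chunk multiple
values = torch.randn(batch, 128, dim)    # retrieved memories after post_rmsnorm

# stand-in for the learned self.empty_memory_embed parameter
empty_memory_embed = torch.zeros(dim)
empty_memory_embeds = repeat(empty_memory_embed, 'd -> b n d', b = batch, n = chunk_size - 1)

# prepend the empty-memory positions, then trim the right-side padding
values = torch.cat((empty_memory_embeds, values), dim = -2)
values = values[:, :-padding]

print(values.shape)  # torch.Size([2, 188, 384])
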
@@ -389,7 +398,7 @@ class NeuralMemory(Module):
         batch, seq_len = seq.shape[:2]

         if seq_len < self.chunk_size:
-            return torch.zeros_like(seq)
+            return self.init_empty_memory_embed(batch, seq_len)

         if exists(past_state):
             past_state = tuple(TensorDict(d) for d in past_state)

@@ -15,7 +15,6 @@ from taylor_series_linear_attention import TaylorSeriesLinearAttn

 from titans_pytorch.titans import (
     NeuralMemory,
-    MemoryAttention,
     MemoryMLP
 )
