titans-pytorch 0.1.1__tar.gz → 0.1.5__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.1.1
+Version: 0.1.5
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -1,6 +1,6 @@
 [project]
 name = "titans-pytorch"
-version = "0.1.1"
+version = "0.1.5"
 description = "Titans"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -11,12 +11,14 @@ def exists(v):
 @pytest.mark.parametrize('seq_len', (32, 1024, 77))
 @pytest.mark.parametrize('silu', (False, True))
 @pytest.mark.parametrize('learned_mem_model_weights', (False, True))
+@pytest.mark.parametrize('attn_pool_chunks', (False, True))
 @pytest.mark.parametrize('max_grad_norm', (None, 2.))
 @pytest.mark.parametrize('per_parameter_lr_modulation', (False, True))
 def test_titans(
     seq_len,
     silu,
     learned_mem_model_weights,
+    attn_pool_chunks,
     max_grad_norm,
     per_parameter_lr_modulation
 ):
@@ -24,6 +26,7 @@ def test_titans(
         dim = 384,
         chunk_size = 64,
         activation = nn.SiLU() if silu else None,
+        attn_pool_chunks = attn_pool_chunks,
         max_grad_norm = max_grad_norm,
         per_parameter_lr_modulation = per_parameter_lr_modulation,
         learned_mem_model_weights = learned_mem_model_weights
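
The new attn_pool_chunks flag is threaded from the test parametrization above into the NeuralMemory constructor. A minimal sketch of enabling it directly, mirroring the updated test; the top-level import and the same-shape return follow the existing tests, and anything not visible in this diff should be treated as an assumption:

import torch
from titans_pytorch import NeuralMemory

mem = NeuralMemory(
    dim = 384,
    chunk_size = 64,
    attn_pool_chunks = True   # attention-pool each chunk instead of taking a plain mean
)

seq = torch.randn(2, 1024, 384)
retrieved = mem(seq)          # as in the tests, retrieved has the same shape as seq
assert seq.shape == retrieved.shape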
@@ -50,10 +53,12 @@ def test_titans_attn_memory():

     assert seq.shape == retrieved.shape

+@pytest.mark.parametrize('seq_len', (1023, 17))
 @pytest.mark.parametrize('num_persist_mem_tokens', (0, 16))
 @pytest.mark.parametrize('num_longterm_mem_tokens', (0, 16))
 @pytest.mark.parametrize('neural_mem_gate_attn_output', (False, True))
 def test_mac(
+    seq_len,
     num_persist_mem_tokens,
     num_longterm_mem_tokens,
     neural_mem_gate_attn_output
@@ -70,13 +75,15 @@ def test_mac(
         neural_mem_gate_attn_output = neural_mem_gate_attn_output
     )

-    x = torch.randint(0, 256, (1, 1023))
+    x = torch.randint(0, 256, (1, seq_len))

     logits = transformer(x)
-    assert logits.shape == (1, 1023, 256)
+    assert logits.shape == (1, seq_len, 256)

+@pytest.mark.parametrize('seq_len', (1023, 17))
 @pytest.mark.parametrize('sliding', (True, False))
 def test_flex(
+    seq_len,
     sliding
 ):
     if not (torch.cuda.is_available() and exists(flex_attention)):
@@ -91,7 +98,7 @@ def test_flex(
         sliding = sliding
     ).cuda()

-    seq = torch.randn(1, 1019, 512).cuda()
+    seq = torch.randn(1, seq_len, 512).cuda()

     out_flex, _ = attn(seq)
     out_non_flex, _ = attn(seq, disable_flex_attn = True)
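
Parametrizing seq_len with values such as 1023 and 17 exercises MemoryAsContextTransformer on lengths that do not line up with chunk or segment boundaries. A sketch of the corresponding forward pass; constructor arguments beyond those named in this diff are taken from the repository's usage examples and should be treated as assumptions:

import torch
from titans_pytorch import MemoryAsContextTransformer

transformer = MemoryAsContextTransformer(
    num_tokens = 256,
    dim = 256,
    depth = 2,
    segment_len = 128,               # assumed, per the repository's usage examples
    num_persist_mem_tokens = 16,
    num_longterm_mem_tokens = 16,
)

x = torch.randint(0, 256, (1, 17))   # length need not be a multiple of the segment length
logits = transformer(x)
assert logits.shape == (1, 17, 256)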
@@ -528,7 +528,8 @@ class MemoryAsContextTransformer(Module):
         filter_fn: Callable = min_p_filter,
         filter_kwargs: dict = dict(
             min_p = 0.1,
-        )
+        ),
+        show_progress = True
     ):
         was_training = self.training
         self.eval()
@@ -536,7 +537,9 @@ class MemoryAsContextTransformer(Module):
         prompt_seq_len, out = prompt.shape[-1], prompt.clone()
         sample_num_times = max(0, seq_len - prompt_seq_len)

-        for _ in tqdm.tqdm(range(sample_num_times)):
+        iter_wrap = tqdm.tqdm if show_progress else identity
+
+        for _ in iter_wrap(range(sample_num_times)):
             logits = self.forward(out, disable_flex_attn = True)
             logits = logits[:, -1]

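show_progress only switches the sampling loop between tqdm.tqdm and the library's identity helper, so callers can silence the progress bar. A hedged usage sketch, assuming sample keeps its existing prompt / seq_len signature and the constructor arguments below (taken from the repository's usage examples) still apply:

import torch
from titans_pytorch import MemoryAsContextTransformer

transformer = MemoryAsContextTransformer(
    num_tokens = 256,
    dim = 256,
    depth = 2,
    segment_len = 128                # assumed, per the repository's usage examples
)

prompt = torch.randint(0, 256, (1, 12))

# suppress the tqdm bar, e.g. when sampling inside an outer progress loop
sampled = transformer.sample(prompt, seq_len = 64, show_progress = False)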
@@ -18,7 +18,7 @@ from titans_pytorch.associative_scan import (
 )

 import einx
-from einops import rearrange, repeat, pack, unpack
+from einops import rearrange, repeat, reduce, pack, unpack
 from einops.layers.torch import Rearrange, Reduce

 """
@@ -95,6 +95,37 @@ class MultiheadRMSNorm(Module):
     def forward(self, x):
         return self.rmsnorm(x) * (self.gamma + 1.)

+# attention pool
+
+class AttentionPool(Module):
+    def __init__(
+        self,
+        dim,
+        chunk_size
+    ):
+        """
+        taken from Enformer https://www.nature.com/articles/s41592-021-01252-x , in turn taken from somewhere else
+        """
+        super().__init__()
+        self.split_chunks = Rearrange('b (n c) d -> b n c d', c = chunk_size)
+        self.to_attn_logits = nn.Linear(dim, dim)
+
+        # default to average pool
+
+        nn.init.zeros_(self.to_attn_logits.weight)
+        nn.init.zeros_(self.to_attn_logits.bias)
+
+    def forward(
+        self,
+        x
+    ):
+        x = self.split_chunks(x)
+        attn_logits = self.to_attn_logits(x)
+
+        attn = attn_logits.softmax(dim = -2)
+
+        return reduce(x * attn, 'b n c d -> b n d', 'sum')
+
 # classes

 class MemoryMLP(Module):
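
AttentionPool computes per-position logits with a linear layer, softmaxes them within each chunk, and takes the weighted sum. Because the linear layer is zero-initialized, the weights start out uniform, so the module behaves exactly like mean pooling until it is trained. A small sketch of that behaviour (the import path is an assumption based on where this diff places the class):

import torch
from titans_pytorch.titans import AttentionPool   # import path assumed from this diff

pool = AttentionPool(dim = 384, chunk_size = 64)

x = torch.randn(2, 1024, 384)                      # length must be a multiple of chunk_size
pooled = pool(x)
assert pooled.shape == (2, 1024 // 64, 384)        # one pooled vector per chunk

# zero-initialized logits -> uniform softmax weights -> identical to a per-chunk mean
mean_pooled = x.reshape(2, -1, 64, 384).mean(dim = -2)
assert torch.allclose(pooled, mean_pooled, atol = 1e-6)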
@@ -224,6 +255,7 @@ class NeuralMemory(Module):
         default_step_transform_max_lr = 1e-2,
         per_parameter_lr_modulation = False, # allow outer network to control learning rate per weight matrix of memory network
         max_mem_layer_modulation = 1e1, # max of 10.
+        attn_pool_chunks = False,
         pre_rmsnorm = True,
         post_rmsnorm = True,
         learned_mem_model_weights = True,
@@ -304,10 +336,17 @@ class NeuralMemory(Module):
         self.empty_memory_embed = nn.Parameter(torch.zeros(dim))
         nn.init.normal_(self.empty_memory_embed, std = 0.02)

+        # whether to use averaging of chunks, or attention pooling
+
+        if not attn_pool_chunks:
+            chunk_reduce_module = Reduce('b (n c) ... -> b n ...', 'mean', c = chunk_size)
+        else:
+            chunk_reduce_module = AttentionPool(dim, chunk_size = chunk_size)
+
         # learned adaptive learning rate and momentum

         self.to_momentum = Sequential(
-            Reduce('b (n c) ... -> b n ...', 'mean', c = chunk_size),
+            chunk_reduce_module,
             LinearNoBias(dim, heads),
             Rearrange('b n h -> (b h) n 1')
         )
@@ -325,7 +364,7 @@ class NeuralMemory(Module):
         # per layer learning rate modulation

         self.to_layer_modulation = Sequential(
-            Reduce('b (n c) ... -> b n ...', 'mean', c = chunk_size),
+            chunk_reduce_module,
             LinearNoBias(dim, heads * self.num_memory_parameter_tensors),
             Rearrange('b n (h w) -> w (b h) n', h = heads),
             nn.Sigmoid()
@@ -340,7 +379,7 @@ class NeuralMemory(Module):
         # weight decay factor

         self.to_decay_factor = Sequential(
-            Reduce('b (n c) ... -> b n ...', 'mean', c = chunk_size),
+            chunk_reduce_module,
             LinearNoBias(dim, heads),
             Rearrange('b n h -> (b h) n 1')
         )
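
The same chunk_reduce_module instance is shared by to_momentum, to_layer_modulation and to_decay_factor above, so with attn_pool_chunks = True a single AttentionPool feeds all three heads, and with it off the shared module is the parameter-free einops mean Reduce. A stand-alone illustration of that wiring pattern (not library code; dims and head count are arbitrary):

import torch
from torch import nn
from einops.layers.torch import Rearrange, Reduce

dim, heads, chunk_size = 384, 4, 64

# one shared chunk reducer in front of several small projection heads
chunk_reduce_module = Reduce('b (n c) ... -> b n ...', 'mean', c = chunk_size)

to_momentum     = nn.Sequential(chunk_reduce_module, nn.Linear(dim, heads, bias = False), Rearrange('b n h -> (b h) n 1'))
to_decay_factor = nn.Sequential(chunk_reduce_module, nn.Linear(dim, heads, bias = False), Rearrange('b n h -> (b h) n 1'))

x = torch.randn(2, 1024, dim)
assert to_momentum(x).shape == to_decay_factor(x).shape == (2 * heads, 1024 // chunk_size, 1)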