titans-pytorch 0.1.2__tar.gz → 0.1.5__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.1.2
+Version: 0.1.5
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -1,6 +1,6 @@
 [project]
 name = "titans-pytorch"
-version = "0.1.2"
+version = "0.1.5"
 description = "Titans"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -11,12 +11,14 @@ def exists(v):
 @pytest.mark.parametrize('seq_len', (32, 1024, 77))
 @pytest.mark.parametrize('silu', (False, True))
 @pytest.mark.parametrize('learned_mem_model_weights', (False, True))
+@pytest.mark.parametrize('attn_pool_chunks', (False, True))
 @pytest.mark.parametrize('max_grad_norm', (None, 2.))
 @pytest.mark.parametrize('per_parameter_lr_modulation', (False, True))
 def test_titans(
     seq_len,
     silu,
     learned_mem_model_weights,
+    attn_pool_chunks,
     max_grad_norm,
     per_parameter_lr_modulation
 ):
@@ -24,6 +26,7 @@ def test_titans(
         dim = 384,
         chunk_size = 64,
         activation = nn.SiLU() if silu else None,
+        attn_pool_chunks = attn_pool_chunks,
         max_grad_norm = max_grad_norm,
         per_parameter_lr_modulation = per_parameter_lr_modulation,
         learned_mem_model_weights = learned_mem_model_weights
@@ -18,7 +18,7 @@ from titans_pytorch.associative_scan import (
 )

 import einx
-from einops import rearrange, repeat, pack, unpack
+from einops import rearrange, repeat, reduce, pack, unpack
 from einops.layers.torch import Rearrange, Reduce

 """
@@ -95,6 +95,37 @@ class MultiheadRMSNorm(Module):
     def forward(self, x):
         return self.rmsnorm(x) * (self.gamma + 1.)

+# attention pool
+
+class AttentionPool(Module):
+    def __init__(
+        self,
+        dim,
+        chunk_size
+    ):
+        """
+        taken from Enformer https://www.nature.com/articles/s41592-021-01252-x , in turn taken from somewhere else
+        """
+        super().__init__()
+        self.split_chunks = Rearrange('b (n c) d -> b n c d', c = chunk_size)
+        self.to_attn_logits = nn.Linear(dim, dim)
+
+        # default to average pool
+
+        nn.init.zeros_(self.to_attn_logits.weight)
+        nn.init.zeros_(self.to_attn_logits.bias)
+
+    def forward(
+        self,
+        x
+    ):
+        x = self.split_chunks(x)
+        attn_logits = self.to_attn_logits(x)
+
+        attn = attn_logits.softmax(dim = -2)
+
+        return reduce(x * attn, 'b n c d -> b n d', 'sum')
+
 # classes

 class MemoryMLP(Module):
@@ -224,6 +255,7 @@ class NeuralMemory(Module):
         default_step_transform_max_lr = 1e-2,
         per_parameter_lr_modulation = False, # allow outer network to control learning rate per weight matrix of memory network
         max_mem_layer_modulation = 1e1, # max of 10.
+        attn_pool_chunks = False,
         pre_rmsnorm = True,
         post_rmsnorm = True,
         learned_mem_model_weights = True,
@@ -304,10 +336,17 @@ class NeuralMemory(Module):
         self.empty_memory_embed = nn.Parameter(torch.zeros(dim))
         nn.init.normal_(self.empty_memory_embed, std = 0.02)

+        # whether to use averaging of chunks, or attention pooling
+
+        if not attn_pool_chunks:
+            chunk_reduce_module = Reduce('b (n c) ... -> b n ...', 'mean', c = chunk_size)
+        else:
+            chunk_reduce_module = AttentionPool(dim, chunk_size = chunk_size)
+
         # learned adaptive learning rate and momentum

         self.to_momentum = Sequential(
-            Reduce('b (n c) ... -> b n ...', 'mean', c = chunk_size),
+            chunk_reduce_module,
             LinearNoBias(dim, heads),
             Rearrange('b n h -> (b h) n 1')
         )
@@ -325,7 +364,7 @@ class NeuralMemory(Module):
         # per layer learning rate modulation

         self.to_layer_modulation = Sequential(
-            Reduce('b (n c) ... -> b n ...', 'mean', c = chunk_size),
+            chunk_reduce_module,
             LinearNoBias(dim, heads * self.num_memory_parameter_tensors),
             Rearrange('b n (h w) -> w (b h) n', h = heads),
             nn.Sigmoid()
@@ -340,7 +379,7 @@ class NeuralMemory(Module):
         # weight decay factor

         self.to_decay_factor = Sequential(
-            Reduce('b (n c) ... -> b n ...', 'mean', c = chunk_size),
+            chunk_reduce_module,
             LinearNoBias(dim, heads),
             Rearrange('b n h -> (b h) n 1')
         )