titans-pytorch 0.1.1__py3-none-any.whl → 0.1.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -528,7 +528,8 @@ class MemoryAsContextTransformer(Module):
         filter_fn: Callable = min_p_filter,
         filter_kwargs: dict = dict(
             min_p = 0.1,
-        )
+        ),
+        show_progress = True
     ):
         was_training = self.training
         self.eval()
@@ -536,7 +537,9 @@ class MemoryAsContextTransformer(Module):
         prompt_seq_len, out = prompt.shape[-1], prompt.clone()
         sample_num_times = max(0, seq_len - prompt_seq_len)
 
-        for _ in tqdm.tqdm(range(sample_num_times)):
+        iter_wrap = tqdm.tqdm if show_progress else identity
+
+        for _ in iter_wrap(range(sample_num_times)):
             logits = self.forward(out, disable_flex_attn = True)
             logits = logits[:, -1]
 
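The new `show_progress` flag simply toggles the tqdm progress bar around the autoregressive decoding loop. A minimal usage sketch follows; it assumes the hunk above belongs to the transformer's sampling method (`MemoryAsContextTransformer.sample`) and that the constructor arguments match the project README, which may differ slightly at this version.

```python
# hedged sketch: assumes MemoryAsContextTransformer.sample(...) is the method
# patched above, and that these constructor keyword arguments match the README
import torch
from titans_pytorch import MemoryAsContextTransformer

transformer = MemoryAsContextTransformer(
    num_tokens = 256,
    dim = 64,
    depth = 2,
    segment_len = 32,
    num_persist_mem_tokens = 4,
    num_longterm_mem_tokens = 4,
)

prompt = torch.randint(0, 256, (1, 16))

# show_progress = False suppresses the tqdm bar that now wraps the decoding loop
sampled = transformer.sample(prompt, 64, show_progress = False)
```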
titans_pytorch/titans.py CHANGED
@@ -18,7 +18,7 @@ from titans_pytorch.associative_scan import (
 )
 
 import einx
-from einops import rearrange, repeat, pack, unpack
+from einops import rearrange, repeat, reduce, pack, unpack
 from einops.layers.torch import Rearrange, Reduce
 
 """
@@ -95,6 +95,37 @@ class MultiheadRMSNorm(Module):
     def forward(self, x):
         return self.rmsnorm(x) * (self.gamma + 1.)
 
+# attention pool
+
+class AttentionPool(Module):
+    def __init__(
+        self,
+        dim,
+        chunk_size
+    ):
+        """
+        taken from Enformer https://www.nature.com/articles/s41592-021-01252-x , in turn taken from somewhere else
+        """
+        super().__init__()
+        self.split_chunks = Rearrange('b (n c) d -> b n c d', c = chunk_size)
+        self.to_attn_logits = nn.Linear(dim, dim)
+
+        # default to average pool
+
+        nn.init.zeros_(self.to_attn_logits.weight)
+        nn.init.zeros_(self.to_attn_logits.bias)
+
+    def forward(
+        self,
+        x
+    ):
+        x = self.split_chunks(x)
+        attn_logits = self.to_attn_logits(x)
+
+        attn = attn_logits.softmax(dim = -2)
+
+        return reduce(x * attn, 'b n c d -> b n d', 'sum')
+
 # classes
 
 class MemoryMLP(Module):
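The zero-initialized linear layer makes `AttentionPool` start out as an exact average pool: all attention logits are zero, so the softmax over each chunk is uniform. A standalone sketch (not part of the package) that checks this equivalence:

```python
# standalone sketch: with zero-initialized logits, attention pooling over a
# chunk reduces to a plain mean over that chunk
import torch
import torch.nn as nn
from einops import reduce
from einops.layers.torch import Rearrange

dim, chunk_size = 8, 4
split_chunks = Rearrange('b (n c) d -> b n c d', c = chunk_size)
to_attn_logits = nn.Linear(dim, dim)
nn.init.zeros_(to_attn_logits.weight)
nn.init.zeros_(to_attn_logits.bias)

x = torch.randn(2, 16, dim)                        # (batch, seq, dim), seq divisible by chunk_size
chunks = split_chunks(x)                           # (b, n, c, d)
attn = to_attn_logits(chunks).softmax(dim = -2)    # uniform (1 / chunk_size) at init
pooled = reduce(chunks * attn, 'b n c d -> b n d', 'sum')

mean_pooled = reduce(x, 'b (n c) d -> b n d', 'mean', c = chunk_size)
assert torch.allclose(pooled, mean_pooled, atol = 1e-6)
```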
@@ -224,6 +255,7 @@ class NeuralMemory(Module):
         default_step_transform_max_lr = 1e-2,
         per_parameter_lr_modulation = False, # allow outer network to control learning rate per weight matrix of memory network
         max_mem_layer_modulation = 1e1, # max of 10.
+        attn_pool_chunks = False,
         pre_rmsnorm = True,
         post_rmsnorm = True,
         learned_mem_model_weights = True,
@@ -304,10 +336,17 @@ class NeuralMemory(Module):
         self.empty_memory_embed = nn.Parameter(torch.zeros(dim))
         nn.init.normal_(self.empty_memory_embed, std = 0.02)
 
+        # whether to use averaging of chunks, or attention pooling
+
+        if not attn_pool_chunks:
+            chunk_reduce_module = Reduce('b (n c) ... -> b n ...', 'mean', c = chunk_size)
+        else:
+            chunk_reduce_module = AttentionPool(dim, chunk_size = chunk_size)
+
         # learned adaptive learning rate and momentum
 
         self.to_momentum = Sequential(
-            Reduce('b (n c) ... -> b n ...', 'mean', c = chunk_size),
+            chunk_reduce_module,
             LinearNoBias(dim, heads),
             Rearrange('b n h -> (b h) n 1')
         )
@@ -325,7 +364,7 @@ class NeuralMemory(Module):
         # per layer learning rate modulation
 
         self.to_layer_modulation = Sequential(
-            Reduce('b (n c) ... -> b n ...', 'mean', c = chunk_size),
+            chunk_reduce_module,
             LinearNoBias(dim, heads * self.num_memory_parameter_tensors),
             Rearrange('b n (h w) -> w (b h) n', h = heads),
             nn.Sigmoid()
@@ -340,7 +379,7 @@ class NeuralMemory(Module):
         # weight decay factor
 
         self.to_decay_factor = Sequential(
-            Reduce('b (n c) ... -> b n ...', 'mean', c = chunk_size),
+            chunk_reduce_module,
             LinearNoBias(dim, heads),
             Rearrange('b n h -> (b h) n 1')
         )
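Taken together, the new `attn_pool_chunks` flag selects between the original per-chunk mean (`Reduce(..., 'mean')`) and the new `AttentionPool` when producing the chunk-level signals for momentum, per-layer learning rate modulation, and weight decay. A hedged usage sketch, assuming the `dim` / `chunk_size` constructor arguments from the project README and leaving everything else at its defaults:

```python
# hedged sketch: constructor arguments follow the project README; the forward
# return value is assumed here to be the retrieved memories for the sequence
import torch
from titans_pytorch import NeuralMemory

mem = NeuralMemory(
    dim = 384,
    chunk_size = 64,
    attn_pool_chunks = True   # use AttentionPool instead of a plain mean over each chunk
)

seq = torch.randn(2, 1024, 384)   # sequence length divisible by chunk_size
retrieved = mem(seq)              # assumed to have the same shape as seq
```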
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.1.1
+Version: 0.1.5
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -0,0 +1,8 @@
+titans_pytorch/__init__.py,sha256=I-4oF1JPEmUvhLszEocM5cDgHYNFcNp0Q9nbDYSPFqU,195
+titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
+titans_pytorch/mac_transformer.py,sha256=YowJzQC2p3fWgzFCe9jXrw1b3wWPKN-dcLs5sX-3r8Y,19123
+titans_pytorch/titans.py,sha256=iF0tTTyLs3hPhJDvGVKD2PdXgpWo9xOggD_42szPwjg,19632
+titans_pytorch-0.1.5.dist-info/METADATA,sha256=GrCMbvIDT9gdL8JJ-U55oxFeB8TVRI2PTuvFK2QQjbk,4684
+titans_pytorch-0.1.5.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+titans_pytorch-0.1.5.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+titans_pytorch-0.1.5.dist-info/RECORD,,
@@ -1,8 +0,0 @@
-titans_pytorch/__init__.py,sha256=I-4oF1JPEmUvhLszEocM5cDgHYNFcNp0Q9nbDYSPFqU,195
-titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
-titans_pytorch/mac_transformer.py,sha256=f_m559p9PLW1yxW6tQfrD43v0eMpeTUxpYdtb32UFgg,19031
-titans_pytorch/titans.py,sha256=cGWJHkOYmIeE6X383mZvyjusECBwbplVvK0cfgfhBxg,18634
-titans_pytorch-0.1.1.dist-info/METADATA,sha256=fxIRTFPGxq9RZMK4yRaAnoG3ym7dfOdaYiatBwmkS6Q,4684
-titans_pytorch-0.1.1.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-titans_pytorch-0.1.1.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-titans_pytorch-0.1.1.dist-info/RECORD,,