titans-pytorch 0.1.2__py3-none-any.whl → 0.1.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
titans_pytorch/__init__.py CHANGED
@@ -2,7 +2,8 @@ from titans_pytorch.titans import (
      NeuralMemory,
      MemoryMLP,
      MemoryAttention,
 -    FactorizedMemoryMLP
 +    FactorizedMemoryMLP,
 +    GatedResidualMemoryMLP
  )

  from titans_pytorch.mac_transformer import (
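The package root now re-exports the new GatedResidualMemoryMLP alongside the existing memory models. A minimal sketch of the updated import, assuming the 0.1.6 wheel is installed:

    from titans_pytorch import (
        NeuralMemory,
        MemoryMLP,
        MemoryAttention,
        FactorizedMemoryMLP,
        GatedResidualMemoryMLP,
    )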
titans_pytorch/titans.py CHANGED
@@ -1,10 +1,11 @@
  from __future__ import annotations
  from typing import Callable
 +
  import math
  from functools import partial

  import torch
 - from torch import nn, Tensor
 + from torch import nn, cat, Tensor
  import torch.nn.functional as F
  from torch.nn import Linear, Module, Parameter, ParameterList
  from torch.func import functional_call, vmap, grad
@@ -18,7 +19,7 @@ from titans_pytorch.associative_scan import (
  )

  import einx
 - from einops import rearrange, repeat, pack, unpack
 + from einops import rearrange, repeat, reduce, pack, unpack
  from einops.layers.torch import Rearrange, Reduce

  """
@@ -95,6 +96,37 @@ class MultiheadRMSNorm(Module):
      def forward(self, x):
          return self.rmsnorm(x) * (self.gamma + 1.)

 + # attention pool
 +
 + class AttentionPool(Module):
 +     def __init__(
 +         self,
 +         dim,
 +         chunk_size
 +     ):
 +         """
 +         taken from Enformer https://www.nature.com/articles/s41592-021-01252-x , in turn taken from somewhere else
 +         """
 +         super().__init__()
 +         self.split_chunks = Rearrange('b (n c) d -> b n c d', c = chunk_size)
 +         self.to_attn_logits = nn.Linear(dim, dim)
 +
 +         # default to average pool
 +
 +         nn.init.zeros_(self.to_attn_logits.weight)
 +         nn.init.zeros_(self.to_attn_logits.bias)
 +
 +     def forward(
 +         self,
 +         x
 +     ):
 +         x = self.split_chunks(x)
 +         attn_logits = self.to_attn_logits(x)
 +
 +         attn = attn_logits.softmax(dim = -2)
 +
 +         return reduce(x * attn, 'b n c d -> b n d', 'sum')
 +
  # classes

  class MemoryMLP(Module):
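The new AttentionPool splits the sequence into chunks of chunk_size and takes a learned, softmax-weighted sum over the positions inside each chunk; because the logits projection is zero-initialized, it starts out equivalent to plain mean pooling. A minimal usage sketch (the sizes below are illustrative, and the sequence length must be a multiple of chunk_size):

    import torch
    from titans_pytorch.titans import AttentionPool

    pool = AttentionPool(dim = 512, chunk_size = 16)   # hypothetical sizes

    x = torch.randn(2, 64, 512)    # (batch, seq, dim)
    pooled = pool(x)               # (2, 4, 512): one pooled vector per chunk of 16 tokens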
@@ -123,6 +155,46 @@ class MemoryMLP(Module):

          return x

 + # memory mlp, but with gated residual + final projection
 +
 + class GatedResidualMemoryMLP(Module):
 +     def __init__(
 +         self,
 +         dim,
 +         depth
 +     ):
 +         super().__init__()
 +         self.depth = depth
 +
 +         self.weights = ParameterList([
 +             ParameterList([
 +                 Parameter(torch.randn(dim, dim)),
 +                 Parameter(torch.randn(dim * 2, dim)),
 +             ]) for _ in range(depth)
 +         ])
 +
 +         self.final_proj = Parameter(torch.randn(dim, dim))
 +
 +         for param in self.parameters():
 +             nn.init.xavier_uniform_(param)
 +
 +     def forward(
 +         self,
 +         x
 +     ):
 +         for weight, to_gates in self.weights:
 +             res = x
 +
 +             x = x @ weight
 +             x = F.silu(x)
 +
 +             # gated residual
 +
 +             gates = cat((x, res), dim = -1) @ to_gates
 +             x = res.lerp(x, gates.sigmoid())
 +
 +         return x @ self.final_proj
 +
  # memory mlp with factorized weights
  # so can tradeoff capacity for smaller chunk sizes

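Each GatedResidualMemoryMLP layer applies a SiLU-activated linear map, then mixes the result back into its input through a sigmoid gate computed from the concatenation of both, and a final projection closes the block, so the output keeps the input's shape. A minimal standalone sketch (sizes are illustrative):

    import torch
    from titans_pytorch import GatedResidualMemoryMLP

    mem_model = GatedResidualMemoryMLP(dim = 64, depth = 2)   # hypothetical sizes

    x = torch.randn(2, 32, 64)
    out = mem_model(x)             # same shape as the input: (2, 32, 64)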
@@ -224,6 +296,7 @@ class NeuralMemory(Module):
          default_step_transform_max_lr = 1e-2,
          per_parameter_lr_modulation = False, # allow outer network to control learning rate per weight matrix of memory network
          max_mem_layer_modulation = 1e1, # max of 10.
 +         attn_pool_chunks = False,
          pre_rmsnorm = True,
          post_rmsnorm = True,
          learned_mem_model_weights = True,
@@ -304,10 +377,17 @@ class NeuralMemory(Module):
          self.empty_memory_embed = nn.Parameter(torch.zeros(dim))
          nn.init.normal_(self.empty_memory_embed, std = 0.02)

 +         # whether to use averaging of chunks, or attention pooling
 +
 +         if not attn_pool_chunks:
 +             chunk_reduce_module = Reduce('b (n c) ... -> b n ...', 'mean', c = chunk_size)
 +         else:
 +             chunk_reduce_module = AttentionPool(dim, chunk_size = chunk_size)
 +
          # learned adaptive learning rate and momentum

          self.to_momentum = Sequential(
 -             Reduce('b (n c) ... -> b n ...', 'mean', c = chunk_size),
 +             chunk_reduce_module,
              LinearNoBias(dim, heads),
              Rearrange('b n h -> (b h) n 1')
          )
@@ -325,7 +405,7 @@
          # per layer learning rate modulation

          self.to_layer_modulation = Sequential(
 -             Reduce('b (n c) ... -> b n ...', 'mean', c = chunk_size),
 +             chunk_reduce_module,
              LinearNoBias(dim, heads * self.num_memory_parameter_tensors),
              Rearrange('b n (h w) -> w (b h) n', h = heads),
              nn.Sigmoid()
@@ -340,7 +420,7 @@
          # weight decay factor

          self.to_decay_factor = Sequential(
 -             Reduce('b (n c) ... -> b n ...', 'mean', c = chunk_size),
 +             chunk_reduce_module,
              LinearNoBias(dim, heads),
              Rearrange('b n h -> (b h) n 1')
          )
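The new attn_pool_chunks flag swaps the per-chunk mean pooling used for the learned momentum, per-layer modulation and decay signals for the AttentionPool defined above. A minimal sketch of turning it on; the sizes follow README-style usage, and the forward call and its return convention are assumed from the existing NeuralMemory API rather than shown in this diff:

    import torch
    from titans_pytorch import NeuralMemory

    mem = NeuralMemory(
        dim = 384,
        chunk_size = 64,
        attn_pool_chunks = True    # new in 0.1.6: attention-pool each chunk instead of averaging it
    )

    seq = torch.randn(2, 1024, 384)
    retrieved = mem(seq)           # retrieved memories for the sequence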
titans_pytorch-0.1.2.dist-info/METADATA → titans_pytorch-0.1.6.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: titans-pytorch
 - Version: 0.1.2
 + Version: 0.1.6
  Summary: Titans
  Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
  Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -46,6 +46,7 @@ Requires-Dist: torch>=2.2
  Requires-Dist: tqdm
  Requires-Dist: x-transformers
  Provides-Extra: examples
 + Requires-Dist: adam-atan2-pytorch>=0.1.18; extra == 'examples'
  Requires-Dist: wandb; extra == 'examples'
  Provides-Extra: test
  Requires-Dist: pytest; extra == 'test'
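The examples extra now also declares adam-atan2-pytorch>=0.1.18 alongside wandb, so installing the extra (for example with pip install 'titans-pytorch[examples]') additionally pulls in that optimizer package.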
titans_pytorch-0.1.6.dist-info/RECORD ADDED
@@ -0,0 +1,8 @@
 + titans_pytorch/__init__.py,sha256=u0tta_KqhOdfzCEDWT9P4_jejJEK2q1XxhsEzB5MnQU,223
 + titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
 + titans_pytorch/mac_transformer.py,sha256=YowJzQC2p3fWgzFCe9jXrw1b3wWPKN-dcLs5sX-3r8Y,19123
 + titans_pytorch/titans.py,sha256=VMcPcKsoR3G13Um62Aa1HbdwrrV60ljPhP-yF40x90I,20555
 + titans_pytorch-0.1.6.dist-info/METADATA,sha256=LJW26WfT9WB-0NfokLLHhcRpWnt76jwkXMt_FSTI3SM,4747
 + titans_pytorch-0.1.6.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
 + titans_pytorch-0.1.6.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
 + titans_pytorch-0.1.6.dist-info/RECORD,,

titans_pytorch-0.1.2.dist-info/RECORD DELETED
@@ -1,8 +0,0 @@
 - titans_pytorch/__init__.py,sha256=I-4oF1JPEmUvhLszEocM5cDgHYNFcNp0Q9nbDYSPFqU,195
 - titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
 - titans_pytorch/mac_transformer.py,sha256=YowJzQC2p3fWgzFCe9jXrw1b3wWPKN-dcLs5sX-3r8Y,19123
 - titans_pytorch/titans.py,sha256=cGWJHkOYmIeE6X383mZvyjusECBwbplVvK0cfgfhBxg,18634
 - titans_pytorch-0.1.2.dist-info/METADATA,sha256=FWq5JIp1WY9dYpzatfGzfkcAGQFk-mEPwxF0wCrbM5w,4684
 - titans_pytorch-0.1.2.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
 - titans_pytorch-0.1.2.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
 - titans_pytorch-0.1.2.dist-info/RECORD,,