titans-pytorch 0.0.65__py3-none-any.whl → 0.1.1__py3-none-any.whl
- titans_pytorch/mac_transformer.py +55 -6
- titans_pytorch/titans.py +38 -8
- {titans_pytorch-0.0.65.dist-info → titans_pytorch-0.1.1.dist-info}/METADATA +13 -3
- titans_pytorch-0.1.1.dist-info/RECORD +8 -0
- titans_pytorch-0.0.65.dist-info/RECORD +0 -8
- {titans_pytorch-0.0.65.dist-info → titans_pytorch-0.1.1.dist-info}/WHEEL +0 -0
- {titans_pytorch-0.0.65.dist-info → titans_pytorch-0.1.1.dist-info}/licenses/LICENSE +0 -0
titans_pytorch/mac_transformer.py CHANGED

@@ -3,6 +3,8 @@ from typing import Callable
 from math import ceil
 from functools import partial
 
+import tqdm
+
 import torch
 from torch import nn, cat
 import torch.nn.functional as F
@@ -88,12 +90,6 @@ def pad_at_dim(t, pad, dim = -1, value = 0.):
 
 def pad_and_segment_with_inverse(seq, segment_len, fold_into_batch = True):
     batch, seq_len = seq.shape[:2]
-
-    need_segment = seq_len >= segment_len
-
-    if not need_segment:
-        return seq, identity
-
     next_seq_len_mult = round_up_multiple(seq_len, segment_len)
 
     padding = next_seq_len_mult - seq_len
@@ -116,6 +112,29 @@ def pad_and_segment_with_inverse(seq, segment_len, fold_into_batch = True):
 
     return seq, inverse
 
+# sampling related
+
+def log(t, eps = 1e-20):
+    return torch.log(t.clamp(min = eps))
+
+def gumbel_noise(t):
+    noise = torch.rand_like(t)
+    return -log(-log(noise))
+
+def gumbel_sample(t, temperature = 1.):
+    if temperature > 0.:
+        t = t / temperature + gumbel_noise(t)
+    return t.argmax(dim = -1, keepdim = True)
+
+# min_p
+# https://arxiv.org/abs/2407.01082
+
+def min_p_filter(logits, min_p = 0.1):
+    probs = logits.softmax(dim = -1)
+    max_probs = probs.amax(dim = -1, keepdim = True)
+    limit = min_p * max_probs
+    return torch.where(probs < limit, float('-inf'), logits)
+
 # feedforward and attention
 
 class GEGLU(Module):
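For readers skimming the diff: the hunk above adds min-p filtering (per the linked paper) and gumbel-based token sampling. Below is a minimal, self-contained sketch of how the two helpers compose into one decoding step; the helper bodies simply mirror the additions above and the tensor shapes are illustrative.

```python
import torch

# self-contained sketch of one decoding step using the helpers added above
# (log, gumbel_noise, gumbel_sample, min_p_filter); shapes are illustrative

def log(t, eps = 1e-20):
    return torch.log(t.clamp(min = eps))

def gumbel_noise(t):
    return -log(-log(torch.rand_like(t)))

def gumbel_sample(t, temperature = 1.):
    if temperature > 0.:
        t = t / temperature + gumbel_noise(t)
    return t.argmax(dim = -1, keepdim = True)

def min_p_filter(logits, min_p = 0.1):
    probs = logits.softmax(dim = -1)
    limit = min_p * probs.amax(dim = -1, keepdim = True)
    return torch.where(probs < limit, float('-inf'), logits)

logits = torch.randn(2, 256)  # (batch, vocab) logits for the last position
next_token = gumbel_sample(min_p_filter(logits), temperature = 1.5)
print(next_token.shape)       # torch.Size([2, 1])
```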
@@ -500,6 +519,36 @@ class MemoryAsContextTransformer(Module):
         self.segment_len = segment_len
         self.num_persist_mem_tokens = num_persist_mem_tokens
 
+    @torch.no_grad()
+    def sample(
+        self,
+        prompt: Tensor,
+        seq_len: int,
+        temperature = 1.5,
+        filter_fn: Callable = min_p_filter,
+        filter_kwargs: dict = dict(
+            min_p = 0.1,
+        )
+    ):
+        was_training = self.training
+        self.eval()
+
+        prompt_seq_len, out = prompt.shape[-1], prompt.clone()
+        sample_num_times = max(0, seq_len - prompt_seq_len)
+
+        for _ in tqdm.tqdm(range(sample_num_times)):
+            logits = self.forward(out, disable_flex_attn = True)
+            logits = logits[:, -1]
+
+            logits = filter_fn(logits, **filter_kwargs)
+            sample = gumbel_sample(logits, temperature = temperature)
+
+            out = torch.cat((out, sample), dim = -1)
+
+        self.train(was_training)
+
+        return out[..., prompt_seq_len:]
+
     def forward(
         self,
         x,
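As orientation on the new generation API: `.sample(prompt, seq_len)` decodes autoregressively under `torch.no_grad()`, applies `min_p_filter` then `gumbel_sample` at each step, and returns only the tokens generated after the prompt. A hedged usage sketch follows; the constructor hyperparameters shown are illustrative rather than canonical (the README example later in this diff shows the documented call).

```python
import torch
from titans_pytorch import MemoryAsContextTransformer

# illustrative hyperparameters only, not prescriptive
transformer = MemoryAsContextTransformer(
    num_tokens = 256,
    dim = 256,
    depth = 2,
    segment_len = 32,
    num_persist_mem_tokens = 4,
)

prompt = torch.randint(0, 256, (1, 4))

# decode out to a total length of 64 tokens; only the newly sampled tokens are returned
sampled = transformer.sample(prompt, 64, temperature = 1.5)
print(sampled.shape)  # (1, 60)
```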
titans_pytorch/titans.py CHANGED

@@ -17,6 +17,7 @@ from titans_pytorch.associative_scan import (
     pad_at_dim
 )
 
+import einx
 from einops import rearrange, repeat, pack, unpack
 from einops.layers.torch import Rearrange, Reduce
 
@@ -26,6 +27,7 @@ b - batch
 n - sequence
 d - feature dimension
 c - intra-chunk
+w - num memory network weight parameters
 """
 
 LinearNoBias = partial(Linear, bias = False)
@@ -220,6 +222,8 @@ class NeuralMemory(Module):
         store_memory_loss_fn: Callable = default_loss_fn,
         adaptive_step_transform: Callable | None = None,
         default_step_transform_max_lr = 1e-2,
+        per_parameter_lr_modulation = False, # allow outer network to control learning rate per weight matrix of memory network
+        max_mem_layer_modulation = 1e1, # max of 10.
         pre_rmsnorm = True,
         post_rmsnorm = True,
         learned_mem_model_weights = True,
@@ -250,7 +254,7 @@ class NeuralMemory(Module):
         self.merge_heads = Rearrange('b h n d -> b n (h d)')
         self.combine_heads = LinearNoBias(dim_inner, dim) if heads > 1 else nn.Identity()
 
-        self.retrieve_gate =
+        self.retrieve_gate = Sequential(
             LinearNoBias(dim, heads),
             Rearrange('b n h -> b h n 1'),
             nn.Sigmoid()
@@ -270,6 +274,8 @@ class NeuralMemory(Module):
 
         self.memory_model = model
 
+        self.num_memory_parameter_tensors = len(set(model.parameters()))
+
         # the chunk size within the paper where adaptive step, momentum, weight decay are shared
 
         self.chunk_size = chunk_size
@@ -299,15 +305,14 @@ class NeuralMemory(Module):
         nn.init.normal_(self.empty_memory_embed, std = 0.02)
 
         # learned adaptive learning rate and momentum
-        # todo - explore mlp layerwise learned lr / momentum
 
-        self.to_momentum =
+        self.to_momentum = Sequential(
             Reduce('b (n c) ... -> b n ...', 'mean', c = chunk_size),
             LinearNoBias(dim, heads),
             Rearrange('b n h -> (b h) n 1')
         )
 
-        self.to_adaptive_step =
+        self.to_adaptive_step = Sequential(
             LinearNoBias(dim, heads),
             Rearrange('b n h -> (b h) n')
         )
@@ -317,13 +322,24 @@ class NeuralMemory(Module):
 
         self.adaptive_step_transform = adaptive_step_transform
 
+        # per layer learning rate modulation
+
+        self.to_layer_modulation = Sequential(
+            Reduce('b (n c) ... -> b n ...', 'mean', c = chunk_size),
+            LinearNoBias(dim, heads * self.num_memory_parameter_tensors),
+            Rearrange('b n (h w) -> w (b h) n', h = heads),
+            nn.Sigmoid()
+        ) if per_parameter_lr_modulation else None
+
+        self.max_mem_layer_modulation = max_mem_layer_modulation
+
         # allow for softclamp the gradient norms for storing memories
 
         self.max_grad_norm = max_grad_norm
 
         # weight decay factor
 
-        self.to_decay_factor =
+        self.to_decay_factor = Sequential(
             Reduce('b (n c) ... -> b n ...', 'mean', c = chunk_size),
             LinearNoBias(dim, heads),
             Rearrange('b n h -> (b h) n 1')
@@ -387,6 +403,11 @@ class NeuralMemory(Module):
         adaptive_momentum = self.to_momentum(seq).sigmoid()
         decay_factor = self.to_decay_factor(seq).sigmoid()
 
+        need_layer_lr_mod = exists(self.to_layer_modulation)
+
+        if need_layer_lr_mod:
+            layer_lr_mod = self.to_layer_modulation(seq) * self.max_mem_layer_modulation
+
         # keys and values
 
         keys, values = self.to_keys_values(seq).chunk(2, dim = -1)
@@ -418,6 +439,11 @@ class NeuralMemory(Module):
 
         grads = grads.apply(lambda t: rearrange(t, '(b n) ... -> b n ...', b = batch))
 
+        # maybe per layer modulation
+
+        if need_layer_lr_mod:
+            grads = TensorDict({name: einx.multiply('b h, b h ... -> b h ...', layer_lr_mod, t) for layer_lr_mod, (name, t) in zip(layer_lr_mod, grads.items())})
+
         # negative gradients, adaptive lr already applied as loss weight
 
         surprises = grads.apply(lambda t: -t)
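Taken together with the `to_layer_modulation` module added earlier, this hunk lets the outer network scale the surprise gradients of each memory-network parameter tensor by its own learned factor (a sigmoid gate scaled up to `max_mem_layer_modulation`). Below is a conceptual sketch of that per-tensor scaling, using plain broadcasting in place of `einx.multiply`; every name and shape here is illustrative, not taken from the package.

```python
import torch

# conceptual sketch of per-parameter-tensor learning rate modulation,
# as enabled by per_parameter_lr_modulation; shapes and names are illustrative

num_param_tensors = 3        # e.g. weights of a two-layer memory MLP plus one bias
batch_heads, chunks = 4, 8   # (b h) and n after the Reduce/Rearrange in to_layer_modulation

# sigmoid output in (0, 1), scaled up to max_mem_layer_modulation (10.)
layer_lr_mod = torch.rand(num_param_tensors, batch_heads, chunks) * 1e1

# per-parameter-tensor gradients of the memory network (trailing dims differ per tensor)
grads = {
    'layer1.weight': torch.randn(batch_heads, chunks, 16, 16),
    'layer2.weight': torch.randn(batch_heads, chunks, 16, 16),
    'layer2.bias':   torch.randn(batch_heads, chunks, 16),
}

# each tensor's gradient gets scaled by its own learned modulation before the surprise update
modulated = {
    name: mod.reshape(batch_heads, chunks, *([1] * (t.ndim - 2))) * t
    for mod, (name, t) in zip(layer_lr_mod, grads.items())
}

print({name: t.shape for name, t in modulated.items()})
```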
@@ -469,7 +495,7 @@ class NeuralMemory(Module):
 
             # use associative scan again for learned forgetting (weight decay) - eq (13)
 
-            update = scan_fn(1. - decay_factor, momentum)
+            update = scan_fn(1. - decay_factor, momentum)
 
             updates[param_name] = inverse_pack(update)
             next_momentum[param_name] = inverse_pack(momentum)
@@ -566,7 +592,12 @@ class NeuralMemory(Module):
         batch, seq_len = seq.shape[:2]
 
         if seq_len < self.chunk_size:
-
+            out = self.init_empty_memory_embed(batch, seq_len)
+
+            if not return_aux_kv_loss:
+                return out
+
+            return out, self.zero
 
         if exists(past_state):
             past_state = tuple(TensorDict(d) for d in past_state)
@@ -580,7 +611,6 @@ class NeuralMemory(Module):
 
         past_weights, _ = past_state
 
-
         retrieved = self.retrieve_memories(seq, past_weights + updates)
 
         if not return_aux_kv_loss:
{titans_pytorch-0.0.65.dist-info → titans_pytorch-0.1.1.dist-info}/METADATA CHANGED

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.0.65
+Version: 0.1.1
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -37,14 +37,15 @@ Requires-Python: >=3.9
 Requires-Dist: accelerated-scan>=0.2.0
 Requires-Dist: axial-positional-embedding>=0.3.5
 Requires-Dist: einops>=0.8.0
+Requires-Dist: einx>=0.3.0
 Requires-Dist: hyper-connections>=0.1.8
 Requires-Dist: ninja
 Requires-Dist: rotary-embedding-torch
 Requires-Dist: tensordict
 Requires-Dist: torch>=2.2
+Requires-Dist: tqdm
 Requires-Dist: x-transformers
 Provides-Extra: examples
-Requires-Dist: tqdm; extra == 'examples'
 Requires-Dist: wandb; extra == 'examples'
 Provides-Extra: test
 Requires-Dist: pytest; extra == 'test'
@@ -58,6 +59,10 @@ Description-Content-Type: text/markdown
 
 Unofficial implementation of [Titans](https://arxiv.org/abs/2501.00663) in Pytorch. Will also contain some explorations into architectures beyond their simple 1-4 layer MLP for the neural memory module, if it works well to any degree.
 
+## Appreciation
+
+- [@sentialx](https://github.com/sentialx) for sharing his early experimental results with me
+
 ## Install
 
 ```bash
@@ -99,7 +104,12 @@ transformer = MemoryAsContextTransformer(
 
 token_ids = torch.randint(0, 256, (1, 1023))
 
-
+loss = transformer(token_ids, return_loss = True) # (1, 1023, 256)
+loss.backward()
+
+# after much training
+
+sampled = transformer.sample(token_ids[:, :4], 512)
 ```
 
 ## Experiments
titans_pytorch-0.1.1.dist-info/RECORD ADDED

@@ -0,0 +1,8 @@
+titans_pytorch/__init__.py,sha256=I-4oF1JPEmUvhLszEocM5cDgHYNFcNp0Q9nbDYSPFqU,195
+titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
+titans_pytorch/mac_transformer.py,sha256=f_m559p9PLW1yxW6tQfrD43v0eMpeTUxpYdtb32UFgg,19031
+titans_pytorch/titans.py,sha256=cGWJHkOYmIeE6X383mZvyjusECBwbplVvK0cfgfhBxg,18634
+titans_pytorch-0.1.1.dist-info/METADATA,sha256=fxIRTFPGxq9RZMK4yRaAnoG3ym7dfOdaYiatBwmkS6Q,4684
+titans_pytorch-0.1.1.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+titans_pytorch-0.1.1.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+titans_pytorch-0.1.1.dist-info/RECORD,,
titans_pytorch-0.0.65.dist-info/RECORD DELETED

@@ -1,8 +0,0 @@
-titans_pytorch/__init__.py,sha256=I-4oF1JPEmUvhLszEocM5cDgHYNFcNp0Q9nbDYSPFqU,195
-titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
-titans_pytorch/mac_transformer.py,sha256=7PHBCbeB1LhHY5s3zAyYF0L3Mm7CNy4TOBbcpLX6bNE,17686
-titans_pytorch/titans.py,sha256=y6lJRErIoM6T2aTVFlf1GxSB0cpsmBZdSIj1DCHUCQ8,17486
-titans_pytorch-0.0.65.dist-info/METADATA,sha256=oDjEiufwOninsFDoCGbu691LXc1mey2OT7j6PNzkz0Q,4457
-titans_pytorch-0.0.65.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-titans_pytorch-0.0.65.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-titans_pytorch-0.0.65.dist-info/RECORD,,

{titans_pytorch-0.0.65.dist-info → titans_pytorch-0.1.1.dist-info}/WHEEL: file without changes

{titans_pytorch-0.0.65.dist-info → titans_pytorch-0.1.1.dist-info}/licenses/LICENSE: file without changes