titans-pytorch 0.0.65__py3-none-any.whl → 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -3,6 +3,8 @@ from typing import Callable
 from math import ceil
 from functools import partial
 
+import tqdm
+
 import torch
 from torch import nn, cat
 import torch.nn.functional as F
@@ -88,12 +90,6 @@ def pad_at_dim(t, pad, dim = -1, value = 0.):
 
 def pad_and_segment_with_inverse(seq, segment_len, fold_into_batch = True):
     batch, seq_len = seq.shape[:2]
-
-    need_segment = seq_len >= segment_len
-
-    if not need_segment:
-        return seq, identity
-
     next_seq_len_mult = round_up_multiple(seq_len, segment_len)
 
     padding = next_seq_len_mult - seq_len
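With the early return for short sequences removed, every input now runs through the padding arithmetic above. A small worked example of that arithmetic (values are illustrative; the body of `round_up_multiple` is assumed to be the usual ceil-based helper):

```python
from math import ceil

def round_up_multiple(n, mult):
    # assumed implementation of the helper referenced above
    return ceil(n / mult) * mult

seq_len, segment_len = 1023, 128
next_seq_len_mult = round_up_multiple(seq_len, segment_len)  # 1024
padding = next_seq_len_mult - seq_len                        # 1 token of padding
```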
@@ -116,6 +112,29 @@ def pad_and_segment_with_inverse(seq, segment_len, fold_into_batch = True):
 
     return seq, inverse
 
+# sampling related
+
+def log(t, eps = 1e-20):
+    return torch.log(t.clamp(min = eps))
+
+def gumbel_noise(t):
+    noise = torch.rand_like(t)
+    return -log(-log(noise))
+
+def gumbel_sample(t, temperature = 1.):
+    if temperature > 0.:
+        t = t / temperature + gumbel_noise(t)
+    return t.argmax(dim = -1, keepdim = True)
+
+# min_p
+# https://arxiv.org/abs/2407.01082
+
+def min_p_filter(logits, min_p = 0.1):
+    probs = logits.softmax(dim = -1)
+    max_probs = probs.amax(dim = -1, keepdim = True)
+    limit = min_p * max_probs
+    return torch.where(probs < limit, float('-inf'), logits)
+
 # feedforward and attention
 
 class GEGLU(Module):
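As a quick, hedged sanity check of the new min-p filter, here it is restated in condensed form and applied to a tiny hand-made distribution (the numbers are illustrative, not taken from the package); `gumbel_sample` above then adds Gumbel noise to the surviving logits when temperature > 0 and takes an argmax:

```python
import torch

def min_p_filter(logits, min_p = 0.1):
    # condensed restatement of the filter in the hunk above
    probs = logits.softmax(dim = -1)
    limit = min_p * probs.amax(dim = -1, keepdim = True)
    return torch.where(probs < limit, float('-inf'), logits)

logits = torch.tensor([[2.0, 1.0, -3.0, -6.0]])   # probs roughly [0.73, 0.27, 0.005, 0.0002]

filtered = min_p_filter(logits, min_p = 0.1)      # keep tokens with prob >= 0.1 * 0.73
print(filtered)                                   # tensor([[2., 1., -inf, -inf]])
print(filtered.argmax(dim = -1, keepdim = True))  # tensor([[0]]) - greedy pick over the survivors
```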
@@ -500,6 +519,36 @@ class MemoryAsContextTransformer(Module):
         self.segment_len = segment_len
         self.num_persist_mem_tokens = num_persist_mem_tokens
 
+    @torch.no_grad()
+    def sample(
+        self,
+        prompt: Tensor,
+        seq_len: int,
+        temperature = 1.5,
+        filter_fn: Callable = min_p_filter,
+        filter_kwargs: dict = dict(
+            min_p = 0.1,
+        )
+    ):
+        was_training = self.training
+        self.eval()
+
+        prompt_seq_len, out = prompt.shape[-1], prompt.clone()
+        sample_num_times = max(0, seq_len - prompt_seq_len)
+
+        for _ in tqdm.tqdm(range(sample_num_times)):
+            logits = self.forward(out, disable_flex_attn = True)
+            logits = logits[:, -1]
+
+            logits = filter_fn(logits, **filter_kwargs)
+            sample = gumbel_sample(logits, temperature = temperature)
+
+            out = torch.cat((out, sample), dim = -1)
+
+        self.train(was_training)
+
+        return out[..., prompt_seq_len:]
+
     def forward(
         self,
         x,
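A hedged usage sketch of the new `sample` method, matching the README snippet further down in this diff (`transformer` and `token_ids` are assumed to be the objects built there):

```python
prompt = token_ids[:, :4]                                    # (1, 4) prompt of token ids

sampled = transformer.sample(prompt, 512)                    # (1, 508) - only the newly generated tokens are returned
greedy  = transformer.sample(prompt, 512, temperature = 0.)  # temperature 0. skips the gumbel noise, i.e. greedy decoding
```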
titans_pytorch/titans.py CHANGED
@@ -17,6 +17,7 @@ from titans_pytorch.associative_scan import (
     pad_at_dim
 )
 
+import einx
 from einops import rearrange, repeat, pack, unpack
 from einops.layers.torch import Rearrange, Reduce
 
@@ -26,6 +27,7 @@ b - batch
 n - sequence
 d - feature dimension
 c - intra-chunk
+w - num memory network weight parameters
 """
 
 LinearNoBias = partial(Linear, bias = False)
@@ -220,6 +222,8 @@ class NeuralMemory(Module):
         store_memory_loss_fn: Callable = default_loss_fn,
         adaptive_step_transform: Callable | None = None,
         default_step_transform_max_lr = 1e-2,
+        per_parameter_lr_modulation = False, # allow outer network to control learning rate per weight matrix of memory network
+        max_mem_layer_modulation = 1e1, # max of 10.
         pre_rmsnorm = True,
         post_rmsnorm = True,
         learned_mem_model_weights = True,
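A hedged construction sketch showing the two new keyword arguments in use; the `dim` and `chunk_size` values are placeholders and everything else keeps its defaults:

```python
import torch
from titans_pytorch import NeuralMemory

mem = NeuralMemory(
    dim = 384,
    chunk_size = 64,
    per_parameter_lr_modulation = True,   # let the outer network scale the lr per memory weight tensor
    max_mem_layer_modulation = 1e1        # the sigmoid output is scaled up to at most 10
)

retrieved = mem(torch.randn(2, 1024, 384))  # basic usage as in the project README
```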
@@ -250,7 +254,7 @@ class NeuralMemory(Module):
         self.merge_heads = Rearrange('b h n d -> b n (h d)')
         self.combine_heads = LinearNoBias(dim_inner, dim) if heads > 1 else nn.Identity()
 
-        self.retrieve_gate = nn.Sequential(
+        self.retrieve_gate = Sequential(
             LinearNoBias(dim, heads),
             Rearrange('b n h -> b h n 1'),
             nn.Sigmoid()
@@ -270,6 +274,8 @@ class NeuralMemory(Module):
 
         self.memory_model = model
 
+        self.num_memory_parameter_tensors = len(set(model.parameters()))
+
         # the chunk size within the paper where adaptive step, momentum, weight decay are shared
 
         self.chunk_size = chunk_size
@@ -299,15 +305,14 @@ class NeuralMemory(Module):
         nn.init.normal_(self.empty_memory_embed, std = 0.02)
 
         # learned adaptive learning rate and momentum
-        # todo - explore mlp layerwise learned lr / momentum
 
-        self.to_momentum = nn.Sequential(
+        self.to_momentum = Sequential(
             Reduce('b (n c) ... -> b n ...', 'mean', c = chunk_size),
             LinearNoBias(dim, heads),
             Rearrange('b n h -> (b h) n 1')
         )
 
-        self.to_adaptive_step = nn.Sequential(
+        self.to_adaptive_step = Sequential(
             LinearNoBias(dim, heads),
             Rearrange('b n h -> (b h) n')
         )
@@ -317,13 +322,24 @@ class NeuralMemory(Module):
 
         self.adaptive_step_transform = adaptive_step_transform
 
+        # per layer learning rate modulation
+
+        self.to_layer_modulation = Sequential(
+            Reduce('b (n c) ... -> b n ...', 'mean', c = chunk_size),
+            LinearNoBias(dim, heads * self.num_memory_parameter_tensors),
+            Rearrange('b n (h w) -> w (b h) n', h = heads),
+            nn.Sigmoid()
+        ) if per_parameter_lr_modulation else None
+
+        self.max_mem_layer_modulation = max_mem_layer_modulation
+
         # allow for softclamp the gradient norms for storing memories
 
         self.max_grad_norm = max_grad_norm
 
         # weight decay factor
 
-        self.to_decay_factor = nn.Sequential(
+        self.to_decay_factor = Sequential(
             Reduce('b (n c) ... -> b n ...', 'mean', c = chunk_size),
             LinearNoBias(dim, heads),
             Rearrange('b n h -> (b h) n 1')
@@ -387,6 +403,11 @@ class NeuralMemory(Module):
         adaptive_momentum = self.to_momentum(seq).sigmoid()
         decay_factor = self.to_decay_factor(seq).sigmoid()
 
+        need_layer_lr_mod = exists(self.to_layer_modulation)
+
+        if need_layer_lr_mod:
+            layer_lr_mod = self.to_layer_modulation(seq) * self.max_mem_layer_modulation
+
         # keys and values
 
         keys, values = self.to_keys_values(seq).chunk(2, dim = -1)
@@ -418,6 +439,11 @@ class NeuralMemory(Module):
 
         grads = grads.apply(lambda t: rearrange(t, '(b n) ... -> b n ...', b = batch))
 
+        # maybe per layer modulation
+
+        if need_layer_lr_mod:
+            grads = TensorDict({name: einx.multiply('b h, b h ... -> b h ...', layer_lr_mod, t) for layer_lr_mod, (name, t) in zip(layer_lr_mod, grads.items())})
+
         # negative gradients, adaptive lr already applied as loss weight
 
         surprises = grads.apply(lambda t: -t)
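The `einx.multiply` call above scales each gradient tensor by its learned modulation, broadcasting the leading axes over all trailing weight dimensions. A minimal sketch of just that broadcast on made-up shapes (the axis letters are only labels here):

```python
import torch
import einx

layer_lr_mod = torch.rand(2, 4) * 10.   # hypothetical: sigmoid output scaled by max_mem_layer_modulation
grad = torch.randn(2, 4, 64, 64)        # hypothetical gradient for one memory weight matrix

scaled = einx.multiply('b h, b h ... -> b h ...', layer_lr_mod, grad)
assert scaled.shape == grad.shape       # every trailing 64 x 64 slice is multiplied by one scalar
```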
@@ -469,7 +495,7 @@ class NeuralMemory(Module):
 
             # use associative scan again for learned forgetting (weight decay) - eq (13)
 
-            update = scan_fn(1. - decay_factor, momentum) # momentum is S / surprise in the paper
+            update = scan_fn(1. - decay_factor, momentum)
 
             updates[param_name] = inverse_pack(update)
             next_momentum[param_name] = inverse_pack(momentum)
@@ -566,7 +592,12 @@ class NeuralMemory(Module):
         batch, seq_len = seq.shape[:2]
 
         if seq_len < self.chunk_size:
-            return self.init_empty_memory_embed(batch, seq_len)
+            out = self.init_empty_memory_embed(batch, seq_len)
+
+            if not return_aux_kv_loss:
+                return out
+
+            return out, self.zero
 
         if exists(past_state):
             past_state = tuple(TensorDict(d) for d in past_state)
@@ -580,7 +611,6 @@ class NeuralMemory(Module):
 
         past_weights, _ = past_state
 
-
         retrieved = self.retrieve_memories(seq, past_weights + updates)
 
         if not return_aux_kv_loss:
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.0.65
+Version: 0.1.1
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -37,14 +37,15 @@ Requires-Python: >=3.9
 Requires-Dist: accelerated-scan>=0.2.0
 Requires-Dist: axial-positional-embedding>=0.3.5
 Requires-Dist: einops>=0.8.0
+Requires-Dist: einx>=0.3.0
 Requires-Dist: hyper-connections>=0.1.8
 Requires-Dist: ninja
 Requires-Dist: rotary-embedding-torch
 Requires-Dist: tensordict
 Requires-Dist: torch>=2.2
+Requires-Dist: tqdm
 Requires-Dist: x-transformers
 Provides-Extra: examples
-Requires-Dist: tqdm; extra == 'examples'
 Requires-Dist: wandb; extra == 'examples'
 Provides-Extra: test
 Requires-Dist: pytest; extra == 'test'
@@ -58,6 +59,10 @@ Description-Content-Type: text/markdown
 
 Unofficial implementation of [Titans](https://arxiv.org/abs/2501.00663) in Pytorch. Will also contain some explorations into architectures beyond their simple 1-4 layer MLP for the neural memory module, if it works well to any degree.
 
+## Appreciation
+
+- [@sentialx](https://github.com/sentialx) for sharing his early experimental results with me
+
 ## Install
 
 ```bash
@@ -99,7 +104,12 @@ transformer = MemoryAsContextTransformer(
 
 token_ids = torch.randint(0, 256, (1, 1023))
 
-logits = transformer(token_ids) # (1, 1023, 256)
+loss = transformer(token_ids, return_loss = True) # (1, 1023, 256)
+loss.backward()
+
+# after much training
+
+sampled = transformer.sample(token_ids[:, :4], 512)
 ```
 
 ## Experiments
@@ -0,0 +1,8 @@
+titans_pytorch/__init__.py,sha256=I-4oF1JPEmUvhLszEocM5cDgHYNFcNp0Q9nbDYSPFqU,195
+titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
+titans_pytorch/mac_transformer.py,sha256=f_m559p9PLW1yxW6tQfrD43v0eMpeTUxpYdtb32UFgg,19031
+titans_pytorch/titans.py,sha256=cGWJHkOYmIeE6X383mZvyjusECBwbplVvK0cfgfhBxg,18634
+titans_pytorch-0.1.1.dist-info/METADATA,sha256=fxIRTFPGxq9RZMK4yRaAnoG3ym7dfOdaYiatBwmkS6Q,4684
+titans_pytorch-0.1.1.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+titans_pytorch-0.1.1.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+titans_pytorch-0.1.1.dist-info/RECORD,,
@@ -1,8 +0,0 @@
-titans_pytorch/__init__.py,sha256=I-4oF1JPEmUvhLszEocM5cDgHYNFcNp0Q9nbDYSPFqU,195
-titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
-titans_pytorch/mac_transformer.py,sha256=7PHBCbeB1LhHY5s3zAyYF0L3Mm7CNy4TOBbcpLX6bNE,17686
-titans_pytorch/titans.py,sha256=y6lJRErIoM6T2aTVFlf1GxSB0cpsmBZdSIj1DCHUCQ8,17486
-titans_pytorch-0.0.65.dist-info/METADATA,sha256=oDjEiufwOninsFDoCGbu691LXc1mey2OT7j6PNzkz0Q,4457
-titans_pytorch-0.0.65.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-titans_pytorch-0.0.65.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-titans_pytorch-0.0.65.dist-info/RECORD,,