titans-pytorch 0.1.0__py3-none-any.whl → 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
titans_pytorch/mac_transformer.py CHANGED
@@ -3,6 +3,8 @@ from typing import Callable
  from math import ceil
  from functools import partial
 
+ import tqdm
+
  import torch
  from torch import nn, cat
  import torch.nn.functional as F
@@ -88,12 +90,6 @@ def pad_at_dim(t, pad, dim = -1, value = 0.):
 
  def pad_and_segment_with_inverse(seq, segment_len, fold_into_batch = True):
      batch, seq_len = seq.shape[:2]
-
-     need_segment = seq_len >= segment_len
-
-     if not need_segment:
-         return seq, identity
-
      next_seq_len_mult = round_up_multiple(seq_len, segment_len)
 
      padding = next_seq_len_mult - seq_len
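
This hunk removes the early return for sequences shorter than `segment_len`, so a short input is now padded up to the next multiple of `segment_len` like any other (the returned inverse still strips the padding afterward). A minimal sketch of the padding step under the new behavior, with made-up sizes and plain `F.pad` standing in for the library's `pad_at_dim` helper:

```python
import torch
import torch.nn.functional as F

# illustrative sizes only; round_up_multiple and the segmenting/inverse
# logic from the library are not reproduced here
segment_len = 32
seq = torch.randint(0, 256, (1, 20))                        # batch of token ids, shorter than one segment

next_mult = -(-seq.shape[1] // segment_len) * segment_len   # ceil to a multiple of segment_len
padding = next_mult - seq.shape[1]

padded = F.pad(seq, (0, padding), value = 0)                # right-pad the sequence dimension
assert padded.shape == (1, 32)                              # 20 -> 32; 0.1.0 would have returned seq unchanged
```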
@@ -116,6 +112,29 @@ def pad_and_segment_with_inverse(seq, segment_len, fold_into_batch = True):
 
      return seq, inverse
 
+ # sampling related
+
+ def log(t, eps = 1e-20):
+     return torch.log(t.clamp(min = eps))
+
+ def gumbel_noise(t):
+     noise = torch.rand_like(t)
+     return -log(-log(noise))
+
+ def gumbel_sample(t, temperature = 1.):
+     if temperature > 0.:
+         t = t / temperature + gumbel_noise(t)
+     return t.argmax(dim = -1, keepdim = True)
+
+ # min_p
+ # https://arxiv.org/abs/2407.01082
+
+ def min_p_filter(logits, min_p = 0.1):
+     probs = logits.softmax(dim = -1)
+     max_probs = probs.amax(dim = -1, keepdim = True)
+     limit = min_p * max_probs
+     return torch.where(probs < limit, float('-inf'), logits)
+
  # feedforward and attention
 
  class GEGLU(Module):
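
The helpers added in this hunk implement min-p filtering (per the linked paper) followed by Gumbel-max sampling: tokens whose probability falls below `min_p` times the top probability are masked to `-inf`, then Gumbel noise is added to the temperature-scaled logits and the argmax is taken, which is equivalent to sampling from the corresponding softmax distribution. A small self-contained sketch on dummy logits; the `log` / `gumbel_noise` definitions below are just inlined copies of the ones in the hunk:

```python
import torch

def log(t, eps = 1e-20):
    return torch.log(t.clamp(min = eps))

def gumbel_noise(t):
    return -log(-log(torch.rand_like(t)))

logits = torch.randn(2, 256)                              # (batch, num_tokens), made-up values

# min-p filter: drop tokens far less likely than the most probable one
probs = logits.softmax(dim = -1)
limit = 0.1 * probs.amax(dim = -1, keepdim = True)
filtered = torch.where(probs < limit, float('-inf'), logits)

# gumbel-max: perturb the scaled logits with Gumbel noise, then take the argmax
temperature = 1.
sampled = (filtered / temperature + gumbel_noise(filtered)).argmax(dim = -1, keepdim = True)

assert sampled.shape == (2, 1)                            # one sampled token id per batch element
```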
@@ -500,6 +519,36 @@ class MemoryAsContextTransformer(Module):
          self.segment_len = segment_len
          self.num_persist_mem_tokens = num_persist_mem_tokens
 
+     @torch.no_grad()
+     def sample(
+         self,
+         prompt: Tensor,
+         seq_len: int,
+         temperature = 1.5,
+         filter_fn: Callable = min_p_filter,
+         filter_kwargs: dict = dict(
+             min_p = 0.1,
+         )
+     ):
+         was_training = self.training
+         self.eval()
+
+         prompt_seq_len, out = prompt.shape[-1], prompt.clone()
+         sample_num_times = max(0, seq_len - prompt_seq_len)
+
+         for _ in tqdm.tqdm(range(sample_num_times)):
+             logits = self.forward(out, disable_flex_attn = True)
+             logits = logits[:, -1]
+
+             logits = filter_fn(logits, **filter_kwargs)
+             sample = gumbel_sample(logits, temperature = temperature)
+
+             out = torch.cat((out, sample), dim = -1)
+
+         self.train(was_training)
+
+         return out[..., prompt_seq_len:]
+
      def forward(
          self,
          x,
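
The new `sample` method decodes autoregressively with no KV cache: each step re-runs the full forward pass (with flex attention disabled), takes the logits of the last position, applies `min_p_filter`, draws the next token with `gumbel_sample`, and appends it, finally returning only the newly generated tokens. A hedged usage sketch, assuming `transformer` is an already trained `MemoryAsContextTransformer` built as in the README excerpt further below:

```python
import torch

# `transformer` is assumed to be a trained MemoryAsContextTransformer instance
# (see the README diff below for construction); the numbers here are illustrative
prompt = torch.randint(0, 256, (1, 4))      # 4 prompt token ids

sampled = transformer.sample(
    prompt,
    64,                                     # total target length, prompt included
    temperature = 1.,                       # overrides the 1.5 default
    filter_kwargs = dict(min_p = 0.05),     # forwarded to min_p_filter
)

assert sampled.shape == (1, 60)             # only the 64 - 4 newly sampled tokens are returned
```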
titans_pytorch/titans.py CHANGED
@@ -592,7 +592,12 @@ class NeuralMemory(Module):
          batch, seq_len = seq.shape[:2]
 
          if seq_len < self.chunk_size:
-             return self.init_empty_memory_embed(batch, seq_len)
+             out = self.init_empty_memory_embed(batch, seq_len)
+
+             if not return_aux_kv_loss:
+                 return out
+
+             return out, self.zero
 
          if exists(past_state):
              past_state = tuple(TensorDict(d) for d in past_state)
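
This patch fixes the short-sequence early return in `NeuralMemory`: when `seq_len < chunk_size` and the caller asked for the auxiliary key/value loss, 0.1.0 handed back a bare tensor where a `(output, loss)` pair was expected; 0.1.1 returns the empty memory embeddings together with `self.zero`. A minimal sketch of the calling pattern this unblocks, assuming the patched block sits on the module's forward path and that `neural_memory` is a `NeuralMemory` instance with `dim = 384, chunk_size = 64` (the values used in the project README):

```python
import torch

# assumed instance: NeuralMemory(dim = 384, chunk_size = 64)
short_seq = torch.randn(1, 16, 384)     # seq_len (16) < chunk_size (64)

# with the fix, the short-sequence path also honors return_aux_kv_loss,
# returning empty memory embeddings plus a zero loss instead of a bare tensor
retrieved, aux_kv_loss = neural_memory(short_seq, return_aux_kv_loss = True)

assert aux_kv_loss.item() == 0.
```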
titans_pytorch-0.1.0.dist-info/METADATA → titans_pytorch-0.1.1.dist-info/METADATA RENAMED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: titans-pytorch
- Version: 0.1.0
+ Version: 0.1.1
  Summary: Titans
  Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
  Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -43,9 +43,9 @@ Requires-Dist: ninja
  Requires-Dist: rotary-embedding-torch
  Requires-Dist: tensordict
  Requires-Dist: torch>=2.2
+ Requires-Dist: tqdm
  Requires-Dist: x-transformers
  Provides-Extra: examples
- Requires-Dist: tqdm; extra == 'examples'
  Requires-Dist: wandb; extra == 'examples'
  Provides-Extra: test
  Requires-Dist: pytest; extra == 'test'
@@ -104,7 +104,12 @@ transformer = MemoryAsContextTransformer(
 
  token_ids = torch.randint(0, 256, (1, 1023))
 
- logits = transformer(token_ids) # (1, 1023, 256)
+ loss = transformer(token_ids, return_loss = True) # scalar loss
+ loss.backward()
+
+ # after much training
+
+ sampled = transformer.sample(token_ids[:, :4], 512)
  ```
 
  ## Experiments
titans_pytorch-0.1.1.dist-info/RECORD ADDED
@@ -0,0 +1,8 @@
+ titans_pytorch/__init__.py,sha256=I-4oF1JPEmUvhLszEocM5cDgHYNFcNp0Q9nbDYSPFqU,195
+ titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
+ titans_pytorch/mac_transformer.py,sha256=f_m559p9PLW1yxW6tQfrD43v0eMpeTUxpYdtb32UFgg,19031
+ titans_pytorch/titans.py,sha256=cGWJHkOYmIeE6X383mZvyjusECBwbplVvK0cfgfhBxg,18634
+ titans_pytorch-0.1.1.dist-info/METADATA,sha256=fxIRTFPGxq9RZMK4yRaAnoG3ym7dfOdaYiatBwmkS6Q,4684
+ titans_pytorch-0.1.1.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+ titans_pytorch-0.1.1.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+ titans_pytorch-0.1.1.dist-info/RECORD,,
titans_pytorch-0.1.0.dist-info/RECORD DELETED
@@ -1,8 +0,0 @@
- titans_pytorch/__init__.py,sha256=I-4oF1JPEmUvhLszEocM5cDgHYNFcNp0Q9nbDYSPFqU,195
- titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
- titans_pytorch/mac_transformer.py,sha256=7PHBCbeB1LhHY5s3zAyYF0L3Mm7CNy4TOBbcpLX6bNE,17686
- titans_pytorch/titans.py,sha256=L3Mu6pDnimD4MNn_832trFEJgXOPjxSdTrB9jiSUSTk,18533
- titans_pytorch-0.1.0.dist-info/METADATA,sha256=LuWDzv-NbGxYKOMThN_WKQWDueyIsOAMSwwiE_BDraI,4595
- titans_pytorch-0.1.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
- titans_pytorch-0.1.0.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
- titans_pytorch-0.1.0.dist-info/RECORD,,