titans-pytorch 0.0.50__py3-none-any.whl → 0.0.52__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
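titans_pytorch/mac_transformer.py CHANGED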
@@ -7,16 +7,48 @@ from torch import nn, cat
  import torch.nn.functional as F
  from torch.nn import Module, ModuleList, Linear

+ # flex attention
+ # https://pytorch.org/blog/flexattention/
+
+ flex_attention = None
+
+ try:
+     from torch.nn.attention.flex_attention import flex_attention, create_block_mask
+     if torch.cuda.is_available():
+         flex_attention = torch.compile(flex_attention)
+ except ImportError:
+     pass
+
+ def create_mac_block_mask(seq_len, window_size, persist_mem_len):
+
+     def create_mac_mask(b, h, q_idx, kv_idx):
+         is_persist_mem = kv_idx < persist_mem_len
+         causal_mask = q_idx >= (kv_idx - is_persist_mem)
+         block_diagonal = (q_idx // window_size) == ((kv_idx - is_persist_mem) // window_size)
+         return is_persist_mem | (~is_persist_mem & (causal_mask & block_diagonal))
+
+     block_mask = create_block_mask(create_mac_mask, B = None, H = None, Q_LEN = seq_len, KV_LEN = seq_len + persist_mem_len, _compile = True)
+     return block_mask
+
+ # einstein notation related
+
  from einops import einsum, repeat, rearrange, pack, unpack
  from einops.layers.torch import Rearrange

- from hyper_connections import get_init_and_expand_reduce_stream_functions
+ # b - batch
+ # n - sequence
+ # h - heads
+ # d - feature dimension

  # absolute and relative positions

  from axial_positional_embedding import ContinuousAxialPositionalEmbedding
  from rotary_embedding_torch import RotaryEmbedding
+
+ # hyper connections / attend from x-transformers, which handles different queries and key lengths better
+
  from x_transformers.attend import Attend
+ from hyper_connections import get_init_and_expand_reduce_stream_functions

  # proposed neural memory

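The new flex-attention path hinges on the `create_mac_mask` predicate above: persistent memory keys are always visible, while ordinary keys are gated by a causal, block-diagonal (windowed) condition. The sketch below is not part of the package; it just materializes the same predicate as a dense boolean mask on CPU, with made-up toy sizes, so the pattern can be inspected without flex attention or a GPU.

```python
import torch

# toy sizes, chosen only for illustration
seq_len, window_size, persist_mem_len = 8, 4, 2

q_idx = torch.arange(seq_len)[:, None]                      # query positions, shape (seq_len, 1)
kv_idx = torch.arange(seq_len + persist_mem_len)[None, :]   # key positions, persistent memory first

# same predicate as create_mac_mask in the hunk above
is_persist_mem = kv_idx < persist_mem_len
causal_mask = q_idx >= (kv_idx - is_persist_mem)
block_diagonal = (q_idx // window_size) == ((kv_idx - is_persist_mem) // window_size)

mask = is_persist_mem | (~is_persist_mem & (causal_mask & block_diagonal))
print(mask.int())  # rows: queries; columns: persistent memory keys, then sequence keys
```

In the package itself, `create_block_mask` compiles this predicate into the block mask that `flex_attention` consumes.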
@@ -96,6 +128,7 @@ class SegmentedAttention(Module):
          heads = 8,
          accept_value_residual = False,
          attend_kwargs: dict = dict(),
+         use_flex_attn = False
      ):
          super().__init__()
          self.norm = nn.RMSNorm(dim)
@@ -125,11 +158,79 @@ class SegmentedAttention(Module):

          self.persistent_memory = nn.Parameter(torch.zeros(2, heads, num_persist_mem_tokens, dim_head))

+         # flex attn related
+
+         assert not (use_flex_attn and not exists(flex_attention)), 'you need to be on the latest pytorch with a cuda device available'
+         self.use_flex_attn = use_flex_attn
+
+         self.segment_len = segment_len
+         self.num_persist_mem_tokens = num_persist_mem_tokens
+
+     def forward_flex(
+         self,
+         seq,
+         value_residual = None,
+         flex_attn_fn: Callable | None = None
+     ):
+
+         assert not (exists(value_residual) ^ exists(self.to_learned_v_mix))
+
+         batch, seq_len = seq.shape[:2]
+
+         # attention
+
+         seq = self.norm(seq)
+
+         q, k, v = self.to_qkv(seq).chunk(3, dim = -1)
+         q, k, v = map(self.split_heads, (q, k, v))
+
+         # value residual
+
+         orig_v = v
+
+         if exists(self.to_learned_v_mix):
+             mix = self.to_learned_v_mix(seq)
+             v = v.lerp(value_residual, mix)
+
+         # take care of persistent memory key / values
+
+         pmk, pmv = repeat(self.persistent_memory, 'kv h n d -> kv b h n d', b = batch)
+
+         # relative positions
+
+         q, k = self.rotary_emb.rotate_queries_with_cached_keys(q, k)
+
+         # persistent memory
+
+         k = cat((pmk, k), dim = -2)
+         v = cat((pmv, v), dim = -2)
+
+         # prep flex attention
+
+         if not exists(flex_attn_fn):
+             block_mask = create_mac_block_mask(seq_len, self.segment_len, self.num_persist_mem_tokens)
+
+             flex_attn_fn = partial(flex_attention, block_mask = block_mask)
+
+         # attention
+
+         out = flex_attn_fn(q, k, v)
+
+         out = self.merge_heads(out)
+
+         out = self.to_out(out)
+
+         return out, orig_v
+
      def forward(
          self,
          seq,
-         value_residual = None
+         value_residual = None,
+         flex_attn_fn: Callable | None = None
      ):
+         if seq.is_cuda and self.use_flex_attn:
+             return self.forward_flex(seq, value_residual, flex_attn_fn)
+
          assert not (exists(value_residual) ^ exists(self.to_learned_v_mix))

          segment_len, num_longterm_mem_tokens = self.segment_len, self.num_longterm_mem_tokens
@@ -159,7 +260,7 @@ class SegmentedAttention(Module):

          # take care of persistent memory key / values

-         pmk, pmv = tuple(repeat(t, 'h n d -> b h n d', b = seq.shape[0]) for t in self.persistent_memory)
+         pmk, pmv = repeat(self.persistent_memory, 'kv ... -> kv b ...', b = seq.shape[0])

          # relative positions

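Taken together, these hunks let `SegmentedAttention` route CUDA inputs through `forward_flex`, with the regular segmented path as the fallback. A hypothetical usage sketch follows; the constructor arguments and import path are inferred from the hunks above rather than from a documented API, so treat the exact signature as an assumption.

```python
import torch
from titans_pytorch.mac_transformer import SegmentedAttention  # import path assumed

use_flex = torch.cuda.is_available()  # flex attention path needs a recent PyTorch + CUDA

attn = SegmentedAttention(
    dim = 384,
    segment_len = 32,             # also the window size used by create_mac_block_mask
    num_persist_mem_tokens = 4,
    use_flex_attn = use_flex      # CPU inputs fall back to the segmented path regardless
)

seq = torch.randn(1, 128, 384)

if use_flex:
    attn, seq = attn.cuda(), seq.cuda()

out, values = attn(seq)  # per the diff, forward returns (out, orig_v)
```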
titans_pytorch/titans.py CHANGED
@@ -56,6 +56,17 @@ def pack_one_with_inverse(t, pattern):

      return packed, inverse

+ def Sequential(*modules):
+     modules = [*filter(exists, modules)]
+
+     if len(modules) == 0:
+         return nn.Identity()
+
+     if len(modules) == 1:
+         return modules[0]
+
+     return nn.Sequential(*modules)
+
  # softclamping gradients

  def softclamp_max(t, max_value):
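The new `Sequential` helper simply drops `None` entries and avoids needless wrappers, which is what allows the `activation` argument further down to default to `None`. A small standalone illustration, with a local `exists` shim standing in for the package's own helper:

```python
import torch.nn as nn

def exists(v):  # shim for the package's helper
    return v is not None

def Sequential(*modules):
    modules = [*filter(exists, modules)]
    if len(modules) == 0:
        return nn.Identity()       # nothing to do
    if len(modules) == 1:
        return modules[0]          # unwrap a single module
    return nn.Sequential(*modules)

print(Sequential())                                  # Identity()
print(Sequential(nn.Linear(4, 4)))                   # bare Linear, unwrapped
print(Sequential(nn.Linear(4, 4), None, nn.SiLU()))  # Sequential(Linear, SiLU), None dropped
```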
@@ -124,9 +135,6 @@ class MemoryAttention(Module):
          ])

      def forward(self, x):
-
-         assert x.shape[-2] > 1, 'chunk size needs to be greater than 1 for using attention as memory'
-
          wq, wk, wv, ffw1, ffw2 = self.weights

          q = F.normalize(x @ wq, dim = -1)
@@ -168,6 +176,7 @@ class NeuralMemory(Module):
          post_rmsnorm = True,
          max_grad_norm: float | None = None,
          use_accelerated_scan = False,
+         activation: Module | None = None,
          default_model_kwargs: dict = dict(
              depth = 2
          )
@@ -225,11 +234,11 @@ class NeuralMemory(Module):

          # queries for retrieving from the model

-         self.to_queries = LinearNoBias(dim, dim_inner)
+         self.to_queries = Sequential(LinearNoBias(dim, dim_inner), activation)

          # keys and values for storing to the model

-         self.to_keys_values = LinearNoBias(dim, dim_inner * 2)
+         self.to_keys_values = Sequential(LinearNoBias(dim, dim_inner * 2), activation)
          self.store_memory_loss_fn = store_memory_loss_fn

          # empty memory embed
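With the two hunks above, passing an `activation` module inserts a nonlinearity after the query and key/value projections, while leaving it as `None` keeps the previous bare linear layers. A hypothetical construction sketch follows; the `chunk_size` argument and the top-level import are assumptions, not taken from this diff.

```python
from torch import nn
from titans_pytorch import NeuralMemory  # top-level export assumed

mem = NeuralMemory(
    dim = 384,
    chunk_size = 64,          # argument name assumed
    activation = nn.SiLU()    # added in this diff; None (the default) keeps plain projections
)

# to_queries / to_keys_values are now Sequential(Linear, SiLU) rather than bare Linear layers
print(mem.to_queries)
print(mem.to_keys_values)
```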
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: titans-pytorch
- Version: 0.0.50
+ Version: 0.0.52
  Summary: Titans
  Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
  Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -135,3 +135,12 @@ $ python train_mac.py
      year = {2024}
  }
  ```
+
+ ```bibtex
+ @inproceedings{Yang2024GatedDN,
+     title = {Gated Delta Networks: Improving Mamba2 with Delta Rule},
+     author = {Songlin Yang and Jan Kautz and Ali Hatamizadeh},
+     year = {2024},
+     url = {https://api.semanticscholar.org/CorpusID:274598177}
+ }
+ ```
@@ -0,0 +1,8 @@
+ titans_pytorch/__init__.py,sha256=wnv_Cjdjqh_h5IqLkQ8xrTtA2K663ITEn-1JeeHofTo,150
+ titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
+ titans_pytorch/mac_transformer.py,sha256=khfjpbsy-uT9NIG3dZLsLOG_XSEi7EqcyfbPr7EQc2Q,13192
+ titans_pytorch/titans.py,sha256=T04onF0xhcrosS-Qkx7fcx-Cqgh0TdU5JLdq9l8ayGg,15911
+ titans_pytorch-0.0.52.dist-info/METADATA,sha256=coC9ExIuNvmab0BktSE1NwUgxRaBUV7h_cTHeoJkRJo,4484
+ titans_pytorch-0.0.52.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+ titans_pytorch-0.0.52.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+ titans_pytorch-0.0.52.dist-info/RECORD,,
@@ -1,8 +0,0 @@
- titans_pytorch/__init__.py,sha256=wnv_Cjdjqh_h5IqLkQ8xrTtA2K663ITEn-1JeeHofTo,150
- titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
- titans_pytorch/mac_transformer.py,sha256=EMhxPt86Vr6LFvPm0OLMFYLaIY19khU9yIHkIhl2EMA,10316
- titans_pytorch/titans.py,sha256=TklMAxNDxgFBpJZFJa8hEhqA_DITmT6EM0p0ueE1jo8,15712
- titans_pytorch-0.0.50.dist-info/METADATA,sha256=KU7TTrH89eNVPP10NKKTDKnW-ik344_kVQkAXW7NRL8,4210
- titans_pytorch-0.0.50.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
- titans_pytorch-0.0.50.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
- titans_pytorch-0.0.50.dist-info/RECORD,,