titans-pytorch 0.0.51__tar.gz → 0.0.52__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.0.51
+Version: 0.0.52
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -1,6 +1,6 @@
 [project]
 name = "titans-pytorch"
-version = "0.0.51"
+version = "0.0.52"
 description = "Titans"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -128,6 +128,7 @@ class SegmentedAttention(Module):
         heads = 8,
         accept_value_residual = False,
         attend_kwargs: dict = dict(),
+        use_flex_attn = False
     ):
         super().__init__()
         self.norm = nn.RMSNorm(dim)
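Note on the hunk above: `use_flex_attn` is a new constructor flag that opts SegmentedAttention into PyTorch's flex attention kernel when the input lives on a CUDA device (see the `forward` dispatch later in this diff). A minimal usage sketch follows; it assumes the constructor also accepts `dim`, `segment_len` and `num_persist_mem_tokens` (names taken only from the surrounding diff context) and that the module lives in `titans_pytorch.mac_transformer`, neither of which is confirmed by this diff alone.

# minimal usage sketch, not from the package docs; constructor keywords other than
# heads / accept_value_residual / attend_kwargs / use_flex_attn are assumptions
import torch
from titans_pytorch.mac_transformer import SegmentedAttention   # module path assumed

attn = SegmentedAttention(
    dim = 512,
    segment_len = 32,               # attention is confined to fixed-length segments
    num_persist_mem_tokens = 4,     # learned key / value prefix that is always attendable
    heads = 8,
    use_flex_attn = True            # new in 0.0.52: route to flex attention when seq is on CUDA
)

seq = torch.randn(1, 128, 512)
out, values = attn(seq)             # (out, values) return assumed from the flex path below;
                                    # on CPU, or with the flag off, the original path runs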
@@ -157,11 +158,79 @@ class SegmentedAttention(Module):
 
         self.persistent_memory = nn.Parameter(torch.zeros(2, heads, num_persist_mem_tokens, dim_head))
 
+        # flex attn related
+
+        assert not (use_flex_attn and not exists(flex_attention)), 'you need to be on the latest pytorch with a cuda device available'
+        self.use_flex_attn = use_flex_attn
+
+        self.segment_len = segment_len
+        self.num_persist_mem_tokens = num_persist_mem_tokens
+
+    def forward_flex(
+        self,
+        seq,
+        value_residual = None,
+        flex_attn_fn: Callable | None = None
+    ):
+
+        assert not (exists(value_residual) ^ exists(self.to_learned_v_mix))
+
+        batch, seq_len = seq.shape[:2]
+
+        # attention
+
+        seq = self.norm(seq)
+
+        q, k, v = self.to_qkv(seq).chunk(3, dim = -1)
+        q, k, v = map(self.split_heads, (q, k, v))
+
+        # value residual
+
+        orig_v = v
+
+        if exists(self.to_learned_v_mix):
+            mix = self.to_learned_v_mix(seq)
+            v = v.lerp(value_residual, mix)
+
+        # take care of persistent memory key / values
+
+        pmk, pmv = repeat(self.persistent_memory, 'kv h n d -> kv b h n d', b = batch)
+
+        # relative positions
+
+        q, k = self.rotary_emb.rotate_queries_with_cached_keys(q, k)
+
+        # persistent memory
+
+        k = cat((pmk, k), dim = -2)
+        v = cat((pmv, v), dim = -2)
+
+        # prep flex attention
+
+        if not exists(flex_attn_fn):
+            block_mask = create_mac_block_mask(seq_len, self.segment_len, self.num_persist_mem_tokens)
+
+            flex_attn_fn = partial(flex_attention, block_mask = block_mask)
+
+        # attention
+
+        out = flex_attn_fn(q, k, v)
+
+        out = self.merge_heads(out)
+
+        out = self.to_out(out)
+
+        return out, orig_v
+
     def forward(
         self,
         seq,
-        value_residual = None
+        value_residual = None,
+        flex_attn_fn: Callable | None = None
     ):
+        if seq.is_cuda and self.use_flex_attn:
+            return self.forward_flex(seq, value_residual, flex_attn_fn)
+
         assert not (exists(value_residual) ^ exists(self.to_learned_v_mix))
 
         segment_len, num_longterm_mem_tokens = self.segment_len, self.num_longterm_mem_tokens
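The new flex path above only runs when `seq.is_cuda` and `use_flex_attn` was set at construction; otherwise the pre-existing `forward` body executes unchanged. When the caller does not pass a `flex_attn_fn`, `forward_flex` builds a block mask with `create_mac_block_mask` (defined elsewhere in this file and not shown in this diff) and partially applies `flex_attention` with it. The sketch below illustrates the kind of mask such a helper typically encodes for memory-as-context attention, using PyTorch's `create_block_mask`: persistent-memory tokens are always visible, and causal attention is confined to fixed-length segments. The function name and the exact mask rule are illustrative assumptions, not the package's implementation.

# illustrative sketch only; the package's create_mac_block_mask may differ in detail
from torch.nn.attention.flex_attention import create_block_mask

def sketch_mac_block_mask(seq_len, segment_len, num_persist_mem_tokens):
    # keys / values carry num_persist_mem_tokens persistent-memory tokens in front,
    # so kv indices are shifted relative to query indices

    def mask_mod(b, h, q_idx, kv_idx):
        is_persist_mem = kv_idx < num_persist_mem_tokens
        kv_shifted = kv_idx - num_persist_mem_tokens
        same_segment = (q_idx // segment_len) == (kv_shifted // segment_len)
        causal = q_idx >= kv_shifted
        return is_persist_mem | (same_segment & causal)

    # create_block_mask defaults to device = 'cuda', matching the CUDA-only dispatch
    return create_block_mask(
        mask_mod,
        B = None,
        H = None,
        Q_LEN = seq_len,
        KV_LEN = seq_len + num_persist_mem_tokens
    )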
@@ -191,7 +260,7 @@ class SegmentedAttention(Module):
 
         # take care of persistent memory key / values
 
-        pmk, pmv = tuple(repeat(t, 'h n d -> b h n d', b = seq.shape[0]) for t in self.persistent_memory)
+        pmk, pmv = repeat(self.persistent_memory, 'kv ... -> kv b ...', b = seq.shape[0])
 
         # relative positions
 
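The last hunk is a pure refactor: instead of repeating each of the two stacked persistent-memory tensors in a generator and then unpacking, the stacked `(2, heads, n, dim_head)` parameter is repeated across the batch in one einops call and unpacked along its leading `kv` axis (the flex path above does the same with an explicit `'kv h n d -> kv b h n d'` pattern). A quick equivalence check, with illustrative shapes:

# equivalence check for the refactor above; shapes are illustrative
import torch
from einops import repeat

persistent_memory = torch.randn(2, 8, 4, 64)   # (kv, heads, mem tokens, dim head)
batch = 3

# old form: repeat each of the two stacked tensors separately, then unpack
old_pmk, old_pmv = tuple(repeat(t, 'h n d -> b h n d', b = batch) for t in persistent_memory)

# new form: repeat the stacked parameter once, unpack along the leading kv axis
new_pmk, new_pmv = repeat(persistent_memory, 'kv ... -> kv b ...', b = batch)

assert torch.equal(old_pmk, new_pmk) and torch.equal(old_pmv, new_pmv)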