titans-pytorch 0.0.54__py3-none-any.whl → 0.0.56__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

titans_pytorch/mac_transformer.py
@@ -227,9 +227,10 @@ class SegmentedAttention(Module):
         self,
         seq,
         value_residual = None,
-        flex_attn_fn: Callable | None = None
+        flex_attn_fn: Callable | None = None,
+        disable_flex_attn = False
     ):
-        if seq.is_cuda and self.use_flex_attn:
+        if seq.is_cuda and self.use_flex_attn and not disable_flex_attn:
             return self.forward_flex(seq, value_residual, flex_attn_fn)

         assert not (exists(value_residual) ^ exists(self.to_learned_v_mix))
@@ -303,7 +304,8 @@ class MemoryAsContextTransformer(Module):
         num_residual_streams = 4,
         neural_memory_kwargs: dict = dict(),
         neural_memory_layers: tuple[int, ...] | None = None,
-        aux_kv_recon_loss_weight = 0.
+        aux_kv_recon_loss_weight = 0.,
+        use_flex_attn = False
     ):
         super().__init__()

@@ -336,6 +338,8 @@ class MemoryAsContextTransformer(Module):

         assert not (num_longterm_mem_tokens > 0 and len(neural_memory_layers) == 0), 'empty `neural_memory_layers` when longterm memory tokens are present'

+        # mem, attn, and feedforward layers
+
         for layer in layers:
             is_first = layer == 1

@@ -363,6 +367,7 @@ class MemoryAsContextTransformer(Module):
                 dim_head = dim_head,
                 heads = heads,
                 segment_len = segment_len,
+                use_flex_attn = use_flex_attn,
                 accept_value_residual = not is_first,
                 num_longterm_mem_tokens = num_longterm_mem_tokens,
                 num_persist_mem_tokens = num_persist_mem_tokens
@@ -386,11 +391,20 @@ class MemoryAsContextTransformer(Module):

         self.register_buffer('zero', torch.tensor(0.), persistent = False)

+        # flex attn related
+
+        assert not (use_flex_attn and not exists(flex_attention)), 'you need to be on the latest pytorch with a cuda device available'
+        self.use_flex_attn = use_flex_attn
+
+        self.segment_len = segment_len
+        self.num_persist_mem_tokens = num_persist_mem_tokens
+
     def forward(
         self,
         x,
         return_loss = False,
-        return_loss_breakdown = False
+        return_loss_breakdown = False,
+        disable_flex_attn = False
     ):

         if return_loss:
@@ -424,6 +438,16 @@ class MemoryAsContextTransformer(Module):

         x = x + pos_emb[:seq_len_with_mem]

+        # prep flex attention
+
+        use_flex_attn = x.is_cuda and self.use_flex_attn and not disable_flex_attn
+
+        flex_attn_fn = None
+
+        if use_flex_attn:
+            block_mask = create_mac_block_mask(seq_len_with_mem, self.segment_len, self.num_persist_mem_tokens)
+            flex_attn_fn = partial(flex_attention, block_mask = block_mask)
+
         # value residual

         value_residual = None
@@ -442,7 +466,12 @@ class MemoryAsContextTransformer(Module):
                 x, aux_kv_loss = maybe_neural_mem(x, return_aux_kv_loss = True)
                 kv_recon_losses = kv_recon_losses + aux_kv_loss

-            x, values = attn(x, value_residual = value_residual)
+            x, values = attn(
+                x,
+                value_residual = value_residual,
+                disable_flex_attn = disable_flex_attn,
+                flex_attn_fn = flex_attn_fn
+            )

             value_residual = default(value_residual, values)

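Taken together, the mac_transformer.py changes expose flex attention as an opt-in: `use_flex_attn` at construction (asserted against the availability of `torch.nn.attention.flex_attention`) and `disable_flex_attn` per forward call. A minimal usage sketch follows; only `use_flex_attn`, `disable_flex_attn`, and `return_loss` come from this diff, while the remaining constructor arguments (`num_tokens`, `dim`, `depth`, etc.) are assumed from the project README and may differ.

    # illustrative only; hyperparameters besides the flex-attention flags are assumptions
    import torch
    from titans_pytorch.mac_transformer import MemoryAsContextTransformer

    transformer = MemoryAsContextTransformer(
        num_tokens = 256,
        dim = 256,
        depth = 2,
        segment_len = 128,
        num_persist_mem_tokens = 4,
        num_longterm_mem_tokens = 16,
        use_flex_attn = True   # new in 0.0.56; requires a recent PyTorch where flex_attention is importable
    )

    ids = torch.randint(0, 256, (1, 1024))

    # on a CUDA device the flex attention path is taken automatically;
    # it can still be bypassed per call with disable_flex_attn = True
    loss = transformer(ids, return_loss = True, disable_flex_attn = True)
    loss.backward()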

titans_pytorch-0.0.54.dist-info/METADATA → titans_pytorch-0.0.56.dist-info/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.0.54
+Version: 0.0.56
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch

titans_pytorch-0.0.56.dist-info/RECORD (new file)
@@ -0,0 +1,8 @@
+titans_pytorch/__init__.py,sha256=Q0MQA3RS8vqzs-KzSGZkvLR7afQ6ZW9uMOq1MeNuFoY,170
+titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
+titans_pytorch/mac_transformer.py,sha256=_Vsco5YuR6uxouWcjFj-s-zPhrBcaapIzqoyi7qqY0Q,14245
+titans_pytorch/titans.py,sha256=ZKm-LnVKh1Cxs2tSxr4CcY37KroOOmYtTFM2F3Zb8Xg,16122
+titans_pytorch-0.0.56.dist-info/METADATA,sha256=QlCmHqajHiaZTps0W9gKXIHE6dShZER3PqPoYi2zRe4,4457
+titans_pytorch-0.0.56.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+titans_pytorch-0.0.56.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+titans_pytorch-0.0.56.dist-info/RECORD,,

titans_pytorch-0.0.54.dist-info/RECORD (removed)
@@ -1,8 +0,0 @@
-titans_pytorch/__init__.py,sha256=Q0MQA3RS8vqzs-KzSGZkvLR7afQ6ZW9uMOq1MeNuFoY,170
-titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
-titans_pytorch/mac_transformer.py,sha256=GdUAYq6MDRGY0l2ESBH_kM01AEzVztiKmWfblSKxBEM,13212
-titans_pytorch/titans.py,sha256=ZKm-LnVKh1Cxs2tSxr4CcY37KroOOmYtTFM2F3Zb8Xg,16122
-titans_pytorch-0.0.54.dist-info/METADATA,sha256=bxbC3NBO4Sjii7DpFPcmNsO9M1kX76vj947H2DeUceg,4457
-titans_pytorch-0.0.54.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-titans_pytorch-0.0.54.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-titans_pytorch-0.0.54.dist-info/RECORD,,