titans-pytorch 0.0.55__tar.gz → 0.0.56__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: titans-pytorch
3
- Version: 0.0.55
3
+ Version: 0.0.56
4
4
  Summary: Titans
5
5
  Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
6
6
  Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "titans-pytorch"
3
- version = "0.0.55"
3
+ version = "0.0.56"
4
4
  description = "Titans"
5
5
  authors = [
6
6
  { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -445,8 +445,7 @@ class MemoryAsContextTransformer(Module):
445
445
  flex_attn_fn = None
446
446
 
447
447
  if use_flex_attn:
448
- block_mask = create_mac_block_mask(seq_len, self.segment_len, self.num_persist_mem_tokens)
449
-
448
+ block_mask = create_mac_block_mask(seq_len_with_mem, self.segment_len, self.num_persist_mem_tokens)
450
449
  flex_attn_fn = partial(flex_attention, block_mask = block_mask)
451
450
 
452
451
  # value residual
@@ -467,7 +466,12 @@ class MemoryAsContextTransformer(Module):
467
466
  x, aux_kv_loss = maybe_neural_mem(x, return_aux_kv_loss = True)
468
467
  kv_recon_losses = kv_recon_losses + aux_kv_loss
469
468
 
470
- x, values = attn(x, value_residual = value_residual, disable_flex_attn = disable_flex_attn, flex_attn_fn = flex_attn_fn)
469
+ x, values = attn(
470
+ x,
471
+ value_residual = value_residual,
472
+ disable_flex_attn = disable_flex_attn,
473
+ flex_attn_fn = flex_attn_fn
474
+ )
471
475
 
472
476
  value_residual = default(value_residual, values)
473
477
 
@@ -113,7 +113,6 @@ model = MemoryAsContextTransformer(
113
113
  neural_memory_layers = NEURAL_MEM_LAYERS,
114
114
  neural_memory_segment_len = WINDOW_SIZE // 2,
115
115
  aux_kv_recon_loss_weight = KV_RECON_LOSS_WEIGHT,
116
- use_flex_attn = True,
117
116
  neural_memory_kwargs = dict(
118
117
  dim_head = 64,
119
118
  heads = 4,
File without changes