titans-pytorch 0.0.58.tar.gz → 0.0.61.tar.gz

This diff shows the changes between two publicly released versions of this package, as they appear in their respective public registries. It is provided for informational purposes only.
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.0.58
+Version: 0.0.61
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -1,6 +1,6 @@
 [project]
 name = "titans-pytorch"
-version = "0.0.58"
+version = "0.0.61"
 description = "Titans"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -24,8 +24,8 @@ def create_mac_block_mask(seq_len, window_size, persist_mem_len):

     def create_mac_mask(b, h, q_idx, kv_idx):
         is_persist_mem = kv_idx < persist_mem_len
-        causal_mask = q_idx >= (kv_idx - is_persist_mem)
-        block_diagonal = (q_idx // window_size) == ((kv_idx - is_persist_mem) // window_size)
+        causal_mask = q_idx >= (kv_idx - persist_mem_len)
+        block_diagonal = (q_idx // window_size) == ((kv_idx - persist_mem_len) // window_size)
         return is_persist_mem | (~is_persist_mem & (causal_mask & block_diagonal))

     block_mask = create_block_mask(create_mac_mask, B = None, H = None, Q_LEN = seq_len, KV_LEN = seq_len + persist_mem_len, _compile = True)
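This hunk fixes the mask arithmetic: the old code subtracted `is_persist_mem`, a boolean that is only ever 0 or 1, where the intent was to shift key positions back by the full `persist_mem_len` so that the first real key lines up with query position 0. A standalone check of the corrected semantics (plain PyTorch, toy sizes of my choosing, no flex attention needed):

```python
import torch

seq_len, window_size, persist_mem_len = 8, 4, 2

q_idx = torch.arange(seq_len).view(-1, 1)                     # query positions (real tokens only)
kv_idx = torch.arange(seq_len + persist_mem_len).view(1, -1)  # key positions, persistent mem tokens first

is_persist_mem = kv_idx < persist_mem_len

# fixed logic: shift kv back by persist_mem_len so key position persist_mem_len
# aligns with query position 0
causal_mask = q_idx >= (kv_idx - persist_mem_len)
block_diagonal = (q_idx // window_size) == ((kv_idx - persist_mem_len) // window_size)

mask = is_persist_mem | (~is_persist_mem & (causal_mask & block_diagonal))
print(mask.int())  # first persist_mem_len columns all ones; the rest is causal within each window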
@@ -489,7 +489,7 @@ class MemoryAsContextTransformer(Module):
         flex_attn_fn = None

         if use_flex_attn:
-            block_mask = create_mac_block_mask(seq_len_with_mem, self.segment_len, self.num_persist_mem_tokens)
+            block_mask = create_mac_block_mask(seq_len_with_mem, segment_len + num_longterm_mem_tokens, self.num_persist_mem_tokens)
             flex_attn_fn = partial(flex_attention, block_mask = block_mask)

         # value residual
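Context for this change: the window the block mask describes evidently has to cover not just the segment but also the long-term memory tokens packed into it, so the window size passed to `create_mac_block_mask` becomes `segment_len + num_longterm_mem_tokens` rather than `self.segment_len` alone. A toy calculation (values of my choosing):

```python
segment_len, num_longterm_mem_tokens = 32, 4
effective_window = segment_len + num_longterm_mem_tokens
print(effective_window)  # 36 - the window size the block mask actually needs
```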
@@ -289,6 +289,8 @@ class NeuralMemory(Module):

         self.use_accelerated_scan = use_accelerated_scan

+        self.register_buffer('zero', torch.tensor(0.), persistent = False)
+
     def init_weights_and_momentum(self):
         params = TensorDict(dict(self.memory_model.named_parameters()))

@@ -306,6 +308,13 @@ class NeuralMemory(Module):
         past_state: tuple[dict[str, Tensor], dict[str, Tensor]],
         return_aux_kv_loss = False
     ):
+        seq_len = seq.shape[-2]
+
+        # handle edge case
+
+        if seq_len < self.chunk_size:
+            past_weight, _ = past_state
+            return TensorDict(past_weight).clone().zero_(), self.zero

         seq = self.store_norm(seq)

@@ -425,12 +434,10 @@ class NeuralMemory(Module):

         last_update = updates.apply(lambda t: t[:, -1])

-        next_state = (curr_weights + last_update, next_momentum)
-
         if not return_aux_kv_loss:
-            return updates, next_state
+            return updates

-        return updates, next_state, aux_kv_recon_loss.mean()
+        return updates, aux_kv_recon_loss.mean()

     def retrieve_memories(
         self,
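`store_memories` no longer assembles `next_state` itself: it returns just the per-chunk updates (plus the mean reconstruction loss when requested), leaving the caller to fold the last update into the current weights. A hedged sketch of that fold, mirroring the `last_update` line visible above (assumes a `tensordict` version with pointwise arithmetic; shapes are illustrative):

```python
import torch
from tensordict import TensorDict

curr_weights = TensorDict({'w': torch.zeros(4)}, batch_size = [])
updates = TensorDict({'w': torch.randn(2, 3, 4)}, batch_size = [])  # (batch, num chunks, *param shape)

last_update = updates.apply(lambda t: t[:, -1])  # keep only the final chunk's update
next_weights = curr_weights + last_update        # what next_state used to carry before this change
```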
@@ -442,7 +449,8 @@ class NeuralMemory(Module):

         seq = self.retrieve_norm(seq)

-        assert seq_len >= chunk_size
+        if seq_len < self.chunk_size:
+            return self.init_empty_memory_embed(batch, seq_len)

         seq = seq[:, (chunk_size - 1):]
         curtailed_seq_len = seq.shape[-2]
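Retrieval gets the matching edge-case treatment: rather than asserting, sequences shorter than one chunk fall back to `init_empty_memory_embed`. A standalone sketch of the guard, where a zero placeholder stands in for whatever `init_empty_memory_embed` actually produces (`chunk_size` and `dim` are toy values):

```python
import torch

def retrieve_or_empty(seq, chunk_size, dim):
    batch, seq_len = seq.shape[0], seq.shape[-2]
    if seq_len < chunk_size:
        # placeholder with the right shape, in lieu of init_empty_memory_embed
        return torch.zeros(batch, seq_len, dim)
    raise NotImplementedError('normal retrieval path elided in this sketch')

out = retrieve_or_empty(torch.randn(2, 3, 16), chunk_size = 4, dim = 16)
print(out.shape)  # torch.Size([2, 3, 16])
```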
@@ -524,7 +532,7 @@ class NeuralMemory(Module):

         store_seq = default(store_seq, seq)

-        updates, next_memories, aux_kv_recon_loss = self.store_memories(store_seq, past_state, return_aux_kv_loss = True)
+        updates, aux_kv_recon_loss = self.store_memories(store_seq, past_state, return_aux_kv_loss = True)

         past_weights, _ = past_state

@@ -24,16 +24,27 @@ GENERATE_LENGTH = 512
 SHOULD_GENERATE = True
 SEQ_LEN = 512

-PROJECT_NAME = 'titans-mac-transformer'
-WANDB_ONLINE = False # turn this on to pipe experiment to cloud
+# neural memory related
+
 NEURAL_MEMORY_DEPTH = 2
 NUM_PERSIST_MEM = 4
 NUM_LONGTERM_MEM = 4
 NEURAL_MEM_LAYERS = (2, 4)
+NEURAL_MEM_GATE_ATTN_OUTPUT = True
 WINDOW_SIZE = 32
 KV_RECON_LOSS_WEIGHT = 0.
 LEARNED_MEM_MODEL_WEIGHTS = True
+
+# experiment related
+
+PROJECT_NAME = 'titans-mac-transformer'
 RUN_NAME = f'mac - {NUM_LONGTERM_MEM} longterm mems, layers {NEURAL_MEM_LAYERS}'
+WANDB_ONLINE = False # turn this on to pipe experiment to cloud
+
+# perf related
+
+USE_ACCELERATED_SCAN = True
+USE_FLEX_ATTN = True

 # wandb experiment tracker

@@ -112,10 +123,13 @@ model = MemoryAsContextTransformer(
     num_longterm_mem_tokens = NUM_LONGTERM_MEM,
     neural_memory_layers = NEURAL_MEM_LAYERS,
     neural_memory_segment_len = WINDOW_SIZE // 2,
+    neural_mem_gate_attn_output = NEURAL_MEM_GATE_ATTN_OUTPUT,
     aux_kv_recon_loss_weight = KV_RECON_LOSS_WEIGHT,
+    use_flex_attn = USE_FLEX_ATTN,
     neural_memory_kwargs = dict(
         dim_head = 64,
         heads = 4,
+        use_accelerated_scan = USE_ACCELERATED_SCAN,
         learned_mem_model_weights = LEARNED_MEM_MODEL_WEIGHTS,
         default_model_kwargs = dict(
             depth = NEURAL_MEMORY_DEPTH,
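For completeness, a hedged construction sketch showing where the newly threaded flags land. Only `neural_mem_gate_attn_output`, `use_flex_attn`, and `use_accelerated_scan` are confirmed by this diff; the remaining constructor arguments are assumed from the project README and may differ at this exact version:

```python
import torch
from titans_pytorch import MemoryAsContextTransformer

model = MemoryAsContextTransformer(
    num_tokens = 256,                    # assumed args, per the README
    dim = 256,
    depth = 2,
    segment_len = 32,
    num_persist_mem_tokens = 4,
    num_longterm_mem_tokens = 4,
    neural_mem_gate_attn_output = True,  # new in this diff
    use_flex_attn = True,                # new in this diff: flex attention + the block mask above
    neural_memory_kwargs = dict(
        use_accelerated_scan = True      # new in this diff: accelerated scan in the neural memory
    )
)
```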