titans-pytorch 0.0.53__tar.gz → 0.0.55__tar.gz

This diff shows the changes between two publicly released versions of this package, as they appear in their respective public registries, and is provided for informational purposes only.
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.0.53
+Version: 0.0.55
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -1,6 +1,6 @@
 [project]
 name = "titans-pytorch"
-version = "0.0.53"
+version = "0.0.55"
 description = "Titans"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -6,17 +6,20 @@ from titans_pytorch import NeuralMemory
 
 @pytest.mark.parametrize('seq_len', (32, 1024, 77))
 @pytest.mark.parametrize('silu', (False, True))
+@pytest.mark.parametrize('learned_mem_model_weights', (False, True))
 @pytest.mark.parametrize('max_grad_norm', (None, 2.))
 def test_titans(
     seq_len,
     silu,
+    learned_mem_model_weights,
     max_grad_norm,
 ):
     mem = NeuralMemory(
         dim = 384,
         chunk_size = 64,
         activation = nn.SiLU() if silu else None,
-        max_grad_norm = max_grad_norm
+        max_grad_norm = max_grad_norm,
+        learned_mem_model_weights = learned_mem_model_weights
     )
 
     seq = torch.randn(2, seq_len, 384)
@@ -1,6 +1,7 @@
 from titans_pytorch.titans import (
     NeuralMemory,
     MemoryMLP,
+    MemoryAttention
 )
 
 from titans_pytorch.mac_transformer import (
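The newly exported MemoryAttention is intended to serve as an alternative memory model. A minimal sketch, assuming NeuralMemory's `model` argument (shown later in this diff) and that a custom memory model runs at dim_head width, as the MemoryMLP(dim_head, ...) default below suggests:

# hedged sketch: argument names come from elsewhere in this diff; the exact
# return value of the forward call is an assumption, not shown here

import torch
from titans_pytorch import NeuralMemory, MemoryAttention

mem = NeuralMemory(
    dim = 384,
    dim_head = 64,
    chunk_size = 64,
    model = MemoryAttention(dim = 64)   # attention-based memory in place of the default MemoryMLP
)

seq = torch.randn(2, 1024, 384)
retrieved = mem(seq)   # assumed to return retrieved memories shaped like seq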
@@ -227,9 +227,10 @@ class SegmentedAttention(Module):
         self,
         seq,
         value_residual = None,
-        flex_attn_fn: Callable | None = None
+        flex_attn_fn: Callable | None = None,
+        disable_flex_attn = False
     ):
-        if seq.is_cuda and self.use_flex_attn:
+        if seq.is_cuda and self.use_flex_attn and not disable_flex_attn:
             return self.forward_flex(seq, value_residual, flex_attn_fn)
 
         assert not (exists(value_residual) ^ exists(self.to_learned_v_mix))
@@ -303,7 +304,8 @@ class MemoryAsContextTransformer(Module):
         num_residual_streams = 4,
         neural_memory_kwargs: dict = dict(),
         neural_memory_layers: tuple[int, ...] | None = None,
-        aux_kv_recon_loss_weight = 0.
+        aux_kv_recon_loss_weight = 0.,
+        use_flex_attn = False
     ):
         super().__init__()
 
@@ -336,6 +338,8 @@ class MemoryAsContextTransformer(Module):
 
         assert not (num_longterm_mem_tokens > 0 and len(neural_memory_layers) == 0), 'empty `neural_memory_layers` when longterm memory tokens are present'
 
+        # mem, attn, and feedforward layers
+
         for layer in layers:
             is_first = layer == 1
 
@@ -363,6 +367,7 @@ class MemoryAsContextTransformer(Module):
                 dim_head = dim_head,
                 heads = heads,
                 segment_len = segment_len,
+                use_flex_attn = use_flex_attn,
                 accept_value_residual = not is_first,
                 num_longterm_mem_tokens = num_longterm_mem_tokens,
                 num_persist_mem_tokens = num_persist_mem_tokens
@@ -386,11 +391,20 @@ class MemoryAsContextTransformer(Module):
 
         self.register_buffer('zero', torch.tensor(0.), persistent = False)
 
+        # flex attn related
+
+        assert not (use_flex_attn and not exists(flex_attention)), 'you need to be on the latest pytorch with a cuda device available'
+        self.use_flex_attn = use_flex_attn
+
+        self.segment_len = segment_len
+        self.num_persist_mem_tokens = num_persist_mem_tokens
+
     def forward(
         self,
         x,
         return_loss = False,
-        return_loss_breakdown = False
+        return_loss_breakdown = False,
+        disable_flex_attn = False
     ):
 
         if return_loss:
@@ -424,6 +438,17 @@ class MemoryAsContextTransformer(Module):
 
         x = x + pos_emb[:seq_len_with_mem]
 
+        # prep flex attention
+
+        use_flex_attn = x.is_cuda and self.use_flex_attn and not disable_flex_attn
+
+        flex_attn_fn = None
+
+        if use_flex_attn:
+            block_mask = create_mac_block_mask(seq_len, self.segment_len, self.num_persist_mem_tokens)
+
+            flex_attn_fn = partial(flex_attention, block_mask = block_mask)
+
         # value residual
 
         value_residual = None
@@ -442,7 +467,7 @@ class MemoryAsContextTransformer(Module):
                 x, aux_kv_loss = maybe_neural_mem(x, return_aux_kv_loss = True)
                 kv_recon_losses = kv_recon_losses + aux_kv_loss
 
-            x, values = attn(x, value_residual = value_residual)
+            x, values = attn(x, value_residual = value_residual, disable_flex_attn = disable_flex_attn, flex_attn_fn = flex_attn_fn)
 
             value_residual = default(value_residual, values)
 
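Putting the new flex attention wiring together at the transformer level: a minimal sketch, assuming constructor arguments such as num_tokens, dim and depth that are not shown in this diff, plus a CUDA device and a PyTorch build that provides torch.nn.attention.flex_attention (per the assert above):

# hedged sketch: toggling the flex attention path per call

import torch
from titans_pytorch import MemoryAsContextTransformer

transformer = MemoryAsContextTransformer(
    num_tokens = 256,              # assumed constructor arguments, not part of this diff
    dim = 256,
    depth = 2,
    segment_len = 32,
    num_persist_mem_tokens = 4,
    num_longterm_mem_tokens = 4,
    use_flex_attn = True           # new flag in this release, defaults to False
).cuda()

ids = torch.randint(0, 256, (1, 512)).cuda()

loss = transformer(ids, return_loss = True)                            # block-masked flex attention path
loss = transformer(ids, return_loss = True, disable_flex_attn = True)  # new per-call fallback to the standard path
loss.backward()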
@@ -123,9 +123,12 @@ class MemoryMLP(Module):
 class MemoryAttention(Module):
     def __init__(
         self,
-        dim
+        dim,
+        scale = 8.
     ):
         super().__init__()
+        self.scale = scale
+
         self.weights = nn.ParameterList([
             nn.Parameter(torch.randn(dim, dim)), # queries
             nn.Parameter(torch.randn(dim, dim)), # keys
@@ -143,6 +146,7 @@ class MemoryAttention(Module):
 
         attn_out = F.scaled_dot_product_attention(
             q, k, v,
+            scale = self.scale,
             is_causal = True
         )
 
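For context, the added `scale` keyword overrides scaled_dot_product_attention's default 1/sqrt(head_dim) softmax scaling. A standalone sketch of just that keyword (available in recent PyTorch releases):

# scale = None (default) divides attention scores by sqrt(head_dim);
# passing an explicit scale, as MemoryAttention now does, fixes it instead

import torch
import torch.nn.functional as F

q = torch.randn(1, 4, 16, 64)
k = torch.randn(1, 4, 16, 64)
v = torch.randn(1, 4, 16, 64)

out_default = F.scaled_dot_product_attention(q, k, v, is_causal = True)              # scores scaled by 1/sqrt(64)
out_fixed   = F.scaled_dot_product_attention(q, k, v, scale = 8., is_causal = True)  # fixed scale of 8.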
@@ -174,6 +178,7 @@ class NeuralMemory(Module):
         default_step_transform_max_lr = 1e-2,
         pre_rmsnorm = True,
         post_rmsnorm = True,
+        learned_mem_model_weights = True,
         max_grad_norm: float | None = None,
         use_accelerated_scan = False,
         activation: Module | None = None,
@@ -212,6 +217,9 @@ class NeuralMemory(Module):
         if not exists(model):
             model = MemoryMLP(dim_head, **default_model_kwargs)
 
+        if not learned_mem_model_weights:
+            model.requires_grad_(False)
+
         assert not exists(next(model.buffers(), None)), 'model cannot have buffers for now'
 
         # the memory is the weights of the model
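A minimal sketch of the new flag, mirroring the updated test earlier in this diff: with learned_mem_model_weights = False the inner memory model is created with requires_grad_(False), so its base weights are excluded from backprop.

# hedged sketch: only the flag and shapes shown in this diff are relied on

import torch
from titans_pytorch import NeuralMemory

mem = NeuralMemory(
    dim = 384,
    chunk_size = 64,
    learned_mem_model_weights = False   # freeze the inner memory model's base weights
)

seq = torch.randn(2, 1024, 384)
retrieved = mem(seq)   # retrieval still runs; the base memory weights simply receive no gradients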
@@ -9,7 +9,7 @@ from torch.optim import Adam
 from torch.nn import functional as F
 from torch.utils.data import DataLoader, Dataset
 
-from titans_pytorch.mac_transformer import MemoryAsContextTransformer
+from titans_pytorch import MemoryAsContextTransformer
 
 # constants
 
@@ -32,6 +32,7 @@ NUM_LONGTERM_MEM = 4
 NEURAL_MEM_LAYERS = (2, 4)
 WINDOW_SIZE = 32
 KV_RECON_LOSS_WEIGHT = 0.
+LEARNED_MEM_MODEL_WEIGHTS = True
 RUN_NAME = f'mac - {NUM_LONGTERM_MEM} longterm mems, layers {NEURAL_MEM_LAYERS}'
 
 # wandb experiment tracker
@@ -90,7 +91,7 @@ def base_decoding(
     sample_num_times = max(0, seq_len - prompt_seq_len)
 
     for _ in tqdm.tqdm(range(sample_num_times)):
-        logits = net(out)
+        logits = net(out, disable_flex_attn = True)
         logits = logits[:, -1]
 
         logits = min_p_filter(logits, min_p = min_p)
@@ -112,9 +113,11 @@ model = MemoryAsContextTransformer(
     neural_memory_layers = NEURAL_MEM_LAYERS,
     neural_memory_segment_len = WINDOW_SIZE // 2,
     aux_kv_recon_loss_weight = KV_RECON_LOSS_WEIGHT,
+    use_flex_attn = True,
     neural_memory_kwargs = dict(
         dim_head = 64,
         heads = 4,
+        learned_mem_model_weights = LEARNED_MEM_MODEL_WEIGHTS,
         default_model_kwargs = dict(
             depth = NEURAL_MEMORY_DEPTH,
         )