titans-pytorch 0.0.64__py3-none-any.whl → 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
titans_pytorch/mac_transformer.py CHANGED
@@ -20,13 +20,21 @@ try:
 except ImportError:
     pass
 
-def create_mac_block_mask(seq_len, window_size, persist_mem_len):
+def create_mac_block_mask(seq_len, window_size, persist_mem_len, sliding = False):
 
-    def create_mac_mask(b, h, q_idx, kv_idx):
+    def create_mac_mask(_, __, q_idx, kv_idx):
         is_persist_mem = kv_idx < persist_mem_len
-        causal_mask = q_idx >= (kv_idx - persist_mem_len)
-        block_diagonal = (q_idx // window_size) == ((kv_idx - persist_mem_len) // window_size)
-        return is_persist_mem | (~is_persist_mem & (causal_mask & block_diagonal))
+        kv_without_mem = kv_idx - persist_mem_len
+        causal_mask = q_idx >= kv_without_mem
+
+        if not sliding:
+            block_diagonal = (q_idx // window_size) == (kv_without_mem // window_size)
+            causal_mask = causal_mask & block_diagonal
+        else:
+            sliding_mask = (q_idx - kv_without_mem) <= window_size
+            causal_mask = causal_mask & sliding_mask
+
+        return is_persist_mem | (~is_persist_mem & causal_mask)
 
     block_mask = create_block_mask(create_mac_mask, B = None, H = None, Q_LEN = seq_len, KV_LEN = seq_len + persist_mem_len, _compile = True)
     return block_mask
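
The new `sliding` flag switches the flex attention mask from strict block-diagonal windows to a sliding causal window of the same width, while persistent memory positions stay fully visible. The predicate can be sanity-checked outside flex attention with plain broadcasting; a minimal standalone sketch with toy sizes, not part of the package:

```python
import torch

def mac_mask(seq_len, window_size, persist_mem_len, sliding = False):
    # same predicate as create_mac_mask above, evaluated densely for inspection
    q_idx = torch.arange(seq_len)[:, None]
    kv_idx = torch.arange(seq_len + persist_mem_len)[None, :]

    is_persist_mem = kv_idx < persist_mem_len
    kv_without_mem = kv_idx - persist_mem_len
    causal_mask = q_idx >= kv_without_mem

    if not sliding:
        # each query only attends within its own window
        block_diagonal = (q_idx // window_size) == (kv_without_mem // window_size)
        causal_mask = causal_mask & block_diagonal
    else:
        # each query may also look back up to window_size tokens into the previous window
        sliding_mask = (q_idx - kv_without_mem) <= window_size
        causal_mask = causal_mask & sliding_mask

    return is_persist_mem | (~is_persist_mem & causal_mask)

print(mac_mask(8, 4, 2, sliding = False).int())
print(mac_mask(8, 4, 2, sliding = True).int())
```
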
@@ -73,7 +81,12 @@ def identity(t):
 def round_up_multiple(seq, mult):
     return ceil(seq / mult) * mult
 
-def pad_and_segment_with_inverse(seq, segment_len):
+def pad_at_dim(t, pad, dim = -1, value = 0.):
+    dims_from_right = (- dim - 1) if dim < 0 else (t.ndim - dim - 1)
+    zeros = ((0, 0) * dims_from_right)
+    return F.pad(t, (*zeros, *pad), value = value)
+
+def pad_and_segment_with_inverse(seq, segment_len, fold_into_batch = True):
     batch, seq_len = seq.shape[:2]
 
     need_segment = seq_len >= segment_len
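
`pad_at_dim` converts a single-dimension pad into the right-to-left `(left, right)` pairs that `F.pad` expects, so tensors can be padded along any axis. A small sketch of how it behaves, duplicating the helper so the snippet runs on its own:

```python
import torch
import torch.nn.functional as F

def pad_at_dim(t, pad, dim = -1, value = 0.):
    # F.pad consumes (left, right) pairs starting from the last dim,
    # so prepend one zero pair per dim to the right of `dim`
    dims_from_right = (- dim - 1) if dim < 0 else (t.ndim - dim - 1)
    zeros = ((0, 0) * dims_from_right)
    return F.pad(t, (*zeros, *pad), value = value)

x = torch.randn(2, 3, 5)
print(pad_at_dim(x, (1, 0), dim = 1).shape)    # torch.Size([2, 4, 5]) - one leading slot on dim 1
print(pad_at_dim(x, (2, 2), dim = -1).shape)   # torch.Size([2, 3, 9])
```
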
@@ -89,13 +102,15 @@ def pad_and_segment_with_inverse(seq, segment_len):
     if needs_pad:
         seq = F.pad(seq, (0, 0, 0, padding))
 
-    seq = rearrange(seq, 'b (w n) d -> (b w) n d', n = segment_len)
+    if fold_into_batch:
+        seq = rearrange(seq, 'b (w n) d -> (b w) n d', n = segment_len)
 
     def inverse(out):
-        out = rearrange(out, '(b w) n d -> b (w n) d', b = batch)
+        if fold_into_batch:
+            out = rearrange(out, '(b w) n d -> b (w n) d', b = batch)
 
         if needs_pad:
-            out = out[:, :-padding]
+            out = out[..., :-padding, :]
 
         return out
 
@@ -127,6 +142,7 @@ class SegmentedAttention(Module):
         num_longterm_mem_tokens = 0,
         dim_head = 64,
         heads = 8,
+        sliding = False,
         accept_value_residual = False,
         attend_kwargs: dict = dict(),
         use_flex_attn = False
@@ -155,6 +171,8 @@ class SegmentedAttention(Module):
         total_segment_len = segment_len + num_longterm_mem_tokens
         self.total_segment_len = total_segment_len
 
+        self.sliding = sliding # sliding window attn - doubt their non-sliding results being the best. local attention with overlapping windows is very strong
+
         self.split_heads = Rearrange('b n (h d) -> b h n d', h = heads)
         self.merge_heads = Rearrange('b h n d -> b n (h d)')
 
@@ -210,7 +228,7 @@ class SegmentedAttention(Module):
         # prep flex attention
 
         if not exists(flex_attn_fn):
-            block_mask = create_mac_block_mask(seq_len, self.total_segment_len, self.num_persist_mem_tokens)
+            block_mask = create_mac_block_mask(seq_len, self.total_segment_len, self.num_persist_mem_tokens, self.sliding)
 
             flex_attn_fn = partial(flex_attention, block_mask = block_mask)
 
@@ -243,7 +261,7 @@ class SegmentedAttention(Module):
 
         # auto pad to multiple
 
-        seq, inverse_segment = pad_and_segment_with_inverse(seq, total_segment_len)
+        seq, inverse_segment = pad_and_segment_with_inverse(seq, total_segment_len, fold_into_batch = False)
 
         # attention
 
@@ -260,14 +278,45 @@ class SegmentedAttention(Module):
             mix = self.to_learned_v_mix(seq)
             v = v.lerp(value_residual, mix)
 
-        # take care of persistent memory key / values
-
-        pmk, pmv = repeat(self.persistent_memory, 'kv ... -> kv b ...', b = seq.shape[0])
-
         # relative positions
 
         q, k = self.rotary_emb.rotate_queries_with_cached_keys(q, k)
 
+        # fold
+
+        q, k, v = tuple(rearrange(t, 'b h (w n) d -> (b w) h n d', n = total_segment_len) for t in (q, k, v))
+
+        # maybe sliding for cpu
+
+        attend_kwargs = dict()
+
+        if self.sliding:
+            k, v = tuple(rearrange(t, '(b w) ... -> b w ...', b = batch) for t in (k, v))
+            k, v = tuple(pad_at_dim(t, (1, 0), value = 0., dim = 1) for t in (k, v))
+            k = cat((k[:, :-1], k[:, 1:]), dim = -2)
+            v = cat((v[:, :-1], v[:, 1:]), dim = -2)
+            k, v = tuple(rearrange(t, 'b w ... -> (b w) ...') for t in (k, v))
+
+            # take care of masking
+
+            idx = torch.arange(seq.shape[-2], device = seq.device)
+            q_idx = rearrange(idx, '(w n) -> w n', n = total_segment_len)
+            k_idx = pad_at_dim(q_idx, (1, 0), dim = 0, value = -1e4)
+            k_idx = cat((k_idx[:-1], k_idx[1:]), dim = -1)
+
+            q_idx = rearrange(q_idx, 'w i -> w i 1')
+            k_idx = rearrange(k_idx, 'w j -> w 1 j')
+
+            sliding_mask = (q_idx - k_idx) <= total_segment_len
+            sliding_mask = F.pad(sliding_mask, (self.num_persist_mem_tokens, 0), value = True)
+
+            sliding_mask = repeat(sliding_mask, 'w i j -> (b w) 1 i j', b = batch)
+            attend_kwargs.update(mask = sliding_mask)
+
+        # take care of persistent memory key / values
+
+        pmk, pmv = repeat(self.persistent_memory, 'kv ... -> kv b ...', b = k.shape[0])
+
         # persistent memory
 
         k = cat((pmk, k), dim = -2)
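
With `sliding = True`, queries, keys and values are first folded into windows, then each window's keys/values are widened to also contain the previous window, and an explicit boolean mask (offset by the persistent memory tokens) is passed to `attend`, since this path does not go through flex attention. A toy sketch of the key/value widening only, with made-up shapes and heads omitted for brevity:

```python
import torch
import torch.nn.functional as F
from torch import cat
from einops import rearrange

# made-up sizes: batch 2, 3 windows of 4 tokens, dim 8
b, w, n, d = 2, 3, 4, 8
k = torch.randn(b * w, n, d)                       # keys already folded into (batch * windows)

k = rearrange(k, '(b w) n d -> b w n d', b = b)    # unfold the window axis
k = F.pad(k, (0, 0, 0, 0, 1, 0))                   # one empty window in front (pad_at_dim(..., dim = 1) in the package)
k = cat((k[:, :-1], k[:, 1:]), dim = -2)           # window i now carries the keys of windows i-1 and i
k = rearrange(k, 'b w n d -> (b w) n d')

print(k.shape)                                     # torch.Size([6, 8, 8]) - keys per window doubled to 2n
```
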
@@ -275,12 +324,14 @@ class SegmentedAttention(Module):
 
         # attention
 
-        out, _ = self.attend(q, k, v)
+        out, _ = self.attend(q, k, v, **attend_kwargs)
 
         out = self.merge_heads(out)
 
         out = self.to_out(out)
 
+        out = rearrange(out, '(b w) n d -> b (w n) d', b = batch)
+
         out = inverse_segment(out)
 
         return out, orig_v
@@ -349,7 +400,8 @@ class MemoryAsContextTransformer(Module):
         neural_memory_kwargs: dict = dict(),
         neural_memory_layers: tuple[int, ...] | None = None,
         aux_kv_recon_loss_weight = 0.,
-        use_flex_attn = False
+        use_flex_attn = False,
+        sliding_window_attn = False
     ):
         super().__init__()
 
@@ -366,6 +418,10 @@ class MemoryAsContextTransformer(Module):
 
         self.longterm_mems = nn.Parameter(torch.randn(num_longterm_mem_tokens, dim) * 0.02)
 
+        # maybe sliding window attn
+
+        self.sliding_window_attn = sliding_window_attn
+
         # hyper conection
 
         init_hyper_conn, self.expand_streams, self.reduce_streams = get_init_and_expand_reduce_stream_functions(num_residual_streams, disable = num_residual_streams == 1)
@@ -396,7 +452,8 @@ class MemoryAsContextTransformer(Module):
                 use_flex_attn = use_flex_attn,
                 accept_value_residual = not is_first,
                 num_longterm_mem_tokens = num_longterm_mem_tokens,
-                num_persist_mem_tokens = num_persist_mem_tokens
+                num_persist_mem_tokens = num_persist_mem_tokens,
+                sliding = sliding_window_attn
             )
 
             mem = None
@@ -489,7 +546,7 @@ class MemoryAsContextTransformer(Module):
         flex_attn_fn = None
 
         if use_flex_attn:
-            block_mask = create_mac_block_mask(seq_len_with_mem, segment_len + num_longterm_mem_tokens, self.num_persist_mem_tokens)
+            block_mask = create_mac_block_mask(seq_len_with_mem, segment_len + num_longterm_mem_tokens, self.num_persist_mem_tokens, self.sliding_window_attn)
             flex_attn_fn = partial(flex_attention, block_mask = block_mask)
 
         # value residual
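
Putting the transformer-side changes together, sliding window attention is switched on with the new `sliding_window_attn` flag, which is threaded down to every `SegmentedAttention` layer and into `create_mac_block_mask` when flex attention is enabled. A hedged usage sketch; apart from the two flags shown in this diff, the constructor arguments and the `return_loss` keyword are assumed from the package README:

```python
import torch
from titans_pytorch import MemoryAsContextTransformer

transformer = MemoryAsContextTransformer(
    num_tokens = 256,                 # these sizes follow the README example and are not part of this diff
    dim = 256,
    depth = 2,
    segment_len = 128,
    num_persist_mem_tokens = 4,
    num_longterm_mem_tokens = 16,
    sliding_window_attn = True,       # new in this release
    use_flex_attn = False             # without flex attention, the explicit sliding mask path above is used
)

ids = torch.randint(0, 256, (1, 1023))
loss = transformer(ids, return_loss = True)
loss.backward()
```
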
titans_pytorch/titans.py CHANGED
@@ -17,6 +17,7 @@ from titans_pytorch.associative_scan import (
     pad_at_dim
 )
 
+import einx
 from einops import rearrange, repeat, pack, unpack
 from einops.layers.torch import Rearrange, Reduce
 
@@ -26,6 +27,7 @@ b - batch
 n - sequence
 d - feature dimension
 c - intra-chunk
+w - num memory network weight parameters
 """
 
 LinearNoBias = partial(Linear, bias = False)
@@ -220,6 +222,8 @@ class NeuralMemory(Module):
         store_memory_loss_fn: Callable = default_loss_fn,
         adaptive_step_transform: Callable | None = None,
         default_step_transform_max_lr = 1e-2,
+        per_parameter_lr_modulation = False, # allow outer network to control learning rate per weight matrix of memory network
+        max_mem_layer_modulation = 1e1, # max of 10.
         pre_rmsnorm = True,
         post_rmsnorm = True,
         learned_mem_model_weights = True,
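
The two new `NeuralMemory` arguments let the outer network scale the learning rate separately for each weight tensor of the memory model, capped at `max_mem_layer_modulation`. A hedged usage sketch; `dim` and `chunk_size` values are assumed from the package README:

```python
import torch
from titans_pytorch import NeuralMemory

mem = NeuralMemory(
    dim = 384,                            # dim / chunk_size follow the README example
    chunk_size = 64,
    per_parameter_lr_modulation = True,   # new: outer network scales the lr per memory weight tensor
    max_mem_layer_modulation = 1e1        # new: upper bound on that scaling
)

seq = torch.randn(2, 1024, 384)
retrieved = mem(seq)
print(retrieved.shape)                    # torch.Size([2, 1024, 384])
```
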
@@ -250,7 +254,7 @@ class NeuralMemory(Module):
         self.merge_heads = Rearrange('b h n d -> b n (h d)')
         self.combine_heads = LinearNoBias(dim_inner, dim) if heads > 1 else nn.Identity()
 
-        self.retrieve_gate = nn.Sequential(
+        self.retrieve_gate = Sequential(
             LinearNoBias(dim, heads),
             Rearrange('b n h -> b h n 1'),
             nn.Sigmoid()
@@ -270,6 +274,8 @@ class NeuralMemory(Module):
 
         self.memory_model = model
 
+        self.num_memory_parameter_tensors = len(set(model.parameters()))
+
         # the chunk size within the paper where adaptive step, momentum, weight decay are shared
 
         self.chunk_size = chunk_size
@@ -299,15 +305,14 @@ class NeuralMemory(Module):
         nn.init.normal_(self.empty_memory_embed, std = 0.02)
 
         # learned adaptive learning rate and momentum
-        # todo - explore mlp layerwise learned lr / momentum
 
-        self.to_momentum = nn.Sequential(
+        self.to_momentum = Sequential(
             Reduce('b (n c) ... -> b n ...', 'mean', c = chunk_size),
             LinearNoBias(dim, heads),
             Rearrange('b n h -> (b h) n 1')
         )
 
-        self.to_adaptive_step = nn.Sequential(
+        self.to_adaptive_step = Sequential(
             LinearNoBias(dim, heads),
             Rearrange('b n h -> (b h) n')
         )
@@ -317,13 +322,24 @@ class NeuralMemory(Module):
 
         self.adaptive_step_transform = adaptive_step_transform
 
+        # per layer learning rate modulation
+
+        self.to_layer_modulation = Sequential(
+            Reduce('b (n c) ... -> b n ...', 'mean', c = chunk_size),
+            LinearNoBias(dim, heads * self.num_memory_parameter_tensors),
+            Rearrange('b n (h w) -> w (b h) n', h = heads),
+            nn.Sigmoid()
+        ) if per_parameter_lr_modulation else None
+
+        self.max_mem_layer_modulation = max_mem_layer_modulation
+
         # allow for softclamp the gradient norms for storing memories
 
         self.max_grad_norm = max_grad_norm
 
         # weight decay factor
 
-        self.to_decay_factor = nn.Sequential(
+        self.to_decay_factor = Sequential(
             Reduce('b (n c) ... -> b n ...', 'mean', c = chunk_size),
             LinearNoBias(dim, heads),
             Rearrange('b n h -> (b h) n 1')
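
`to_layer_modulation` mean-pools each chunk, projects to `heads * num_memory_parameter_tensors` values, and rearranges so the leading axis indexes the memory network's weight tensors. A standalone shape check with made-up sizes, using `nn.Sequential` and `nn.Linear(bias = False)` in place of the package's `Sequential` and `LinearNoBias` helpers:

```python
import torch
from torch import nn
from einops.layers.torch import Rearrange, Reduce

# made-up sizes: dim 64, 4 heads, 3 memory weight tensors, chunk size 8, batch 2, seq len 32
dim, heads, num_weights, chunk_size = 64, 4, 3, 8

to_layer_modulation = nn.Sequential(
    Reduce('b (n c) ... -> b n ...', 'mean', c = chunk_size),   # mean-pool each chunk
    nn.Linear(dim, heads * num_weights, bias = False),
    Rearrange('b n (h w) -> w (b h) n', h = heads),             # leading axis = memory weight tensors
    nn.Sigmoid()
)

seq = torch.randn(2, 32, dim)
print(to_layer_modulation(seq).shape)   # torch.Size([3, 8, 4]) - (weight tensors, batch * heads, chunks)
```
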
@@ -387,6 +403,11 @@ class NeuralMemory(Module):
         adaptive_momentum = self.to_momentum(seq).sigmoid()
         decay_factor = self.to_decay_factor(seq).sigmoid()
 
+        need_layer_lr_mod = exists(self.to_layer_modulation)
+
+        if need_layer_lr_mod:
+            layer_lr_mod = self.to_layer_modulation(seq) * self.max_mem_layer_modulation
+
         # keys and values
 
         keys, values = self.to_keys_values(seq).chunk(2, dim = -1)
@@ -418,6 +439,11 @@ class NeuralMemory(Module):
 
         grads = grads.apply(lambda t: rearrange(t, '(b n) ... -> b n ...', b = batch))
 
+        # maybe per layer modulation
+
+        if need_layer_lr_mod:
+            grads = TensorDict({name: einx.multiply('b h, b h ... -> b h ...', layer_lr_mod, t) for layer_lr_mod, (name, t) in zip(layer_lr_mod, grads.items())})
+
         # negative gradients, adaptive lr already applied as loss weight
 
         surprises = grads.apply(lambda t: -t)
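
The modulation is applied with `einx.multiply` (hence the new `einx` dependency): for each memory weight tensor, the two leading axes of the modulation and the gradient line up, and the modulation value broadcasts over the remaining gradient dimensions. A standalone sketch of that broadcast with illustrative shapes:

```python
import torch
import einx

b, h, i, j = 2, 4, 16, 32
layer_lr_mod = torch.rand(b, h)      # one modulation value per leading position for this weight tensor
grad = torch.randn(b, h, i, j)       # matching gradient for that memory weight

out = einx.multiply('b h, b h ... -> b h ...', layer_lr_mod, grad)
assert torch.allclose(out, layer_lr_mod[..., None, None] * grad)
```
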
@@ -469,7 +495,7 @@ class NeuralMemory(Module):
 
             # use associative scan again for learned forgetting (weight decay) - eq (13)
 
-            update = scan_fn(1. - decay_factor, momentum) # momentum is S / surprise in the paper
+            update = scan_fn(1. - decay_factor, momentum)
 
             updates[param_name] = inverse_pack(update)
             next_momentum[param_name] = inverse_pack(momentum)
@@ -580,7 +606,6 @@ class NeuralMemory(Module):
 
         past_weights, _ = past_state
 
-
         retrieved = self.retrieve_memories(seq, past_weights + updates)
 
         if not return_aux_kv_loss:
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.0.64
+Version: 0.1.0
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -37,6 +37,7 @@ Requires-Python: >=3.9
 Requires-Dist: accelerated-scan>=0.2.0
 Requires-Dist: axial-positional-embedding>=0.3.5
 Requires-Dist: einops>=0.8.0
+Requires-Dist: einx>=0.3.0
 Requires-Dist: hyper-connections>=0.1.8
 Requires-Dist: ninja
 Requires-Dist: rotary-embedding-torch
@@ -58,6 +59,10 @@ Description-Content-Type: text/markdown
 
 Unofficial implementation of [Titans](https://arxiv.org/abs/2501.00663) in Pytorch. Will also contain some explorations into architectures beyond their simple 1-4 layer MLP for the neural memory module, if it works well to any degree.
 
+## Appreciation
+
+- [@sentialx](https://github.com/sentialx) for sharing his early experimental results with me
+
 ## Install
 
 ```bash
@@ -0,0 +1,8 @@
+titans_pytorch/__init__.py,sha256=I-4oF1JPEmUvhLszEocM5cDgHYNFcNp0Q9nbDYSPFqU,195
+titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
+titans_pytorch/mac_transformer.py,sha256=7PHBCbeB1LhHY5s3zAyYF0L3Mm7CNy4TOBbcpLX6bNE,17686
+titans_pytorch/titans.py,sha256=L3Mu6pDnimD4MNn_832trFEJgXOPjxSdTrB9jiSUSTk,18533
+titans_pytorch-0.1.0.dist-info/METADATA,sha256=LuWDzv-NbGxYKOMThN_WKQWDueyIsOAMSwwiE_BDraI,4595
+titans_pytorch-0.1.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+titans_pytorch-0.1.0.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+titans_pytorch-0.1.0.dist-info/RECORD,,
@@ -1,8 +0,0 @@
-titans_pytorch/__init__.py,sha256=I-4oF1JPEmUvhLszEocM5cDgHYNFcNp0Q9nbDYSPFqU,195
-titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
-titans_pytorch/mac_transformer.py,sha256=7voYtbD_ErCD_JjvwhAiunUWtSIsIxGJAaf2aRB3c2s,15349
-titans_pytorch/titans.py,sha256=y6lJRErIoM6T2aTVFlf1GxSB0cpsmBZdSIj1DCHUCQ8,17486
-titans_pytorch-0.0.64.dist-info/METADATA,sha256=K63jobSfTdn-aFpEpZgolu4zSIvgUzF2rDuoCHGXkgE,4457
-titans_pytorch-0.0.64.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-titans_pytorch-0.0.64.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-titans_pytorch-0.0.64.dist-info/RECORD,,