titans-pytorch 0.0.65__py3-none-any.whl → 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
titans_pytorch/titans.py CHANGED
@@ -17,6 +17,7 @@ from titans_pytorch.associative_scan import (
     pad_at_dim
 )
 
+import einx
 from einops import rearrange, repeat, pack, unpack
 from einops.layers.torch import Rearrange, Reduce
 
@@ -26,6 +27,7 @@ b - batch
 n - sequence
 d - feature dimension
 c - intra-chunk
+w - num memory network weight parameters
 """
 
 LinearNoBias = partial(Linear, bias = False)
@@ -220,6 +222,8 @@ class NeuralMemory(Module):
         store_memory_loss_fn: Callable = default_loss_fn,
         adaptive_step_transform: Callable | None = None,
         default_step_transform_max_lr = 1e-2,
+        per_parameter_lr_modulation = False, # allow outer network to control learning rate per weight matrix of memory network
+        max_mem_layer_modulation = 1e1, # max of 10.
         pre_rmsnorm = True,
         post_rmsnorm = True,
         learned_mem_model_weights = True,
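A minimal usage sketch for the two new keyword arguments, assuming `NeuralMemory` is imported from the package top level as in the project README; the `dim` / `chunk_size` values and the sequence shape are illustrative, not taken from this diff:

```python
import torch
from titans_pytorch import NeuralMemory

mem = NeuralMemory(
    dim = 384,
    chunk_size = 64,
    per_parameter_lr_modulation = True,  # new in 0.1.0: outer network scales the lr per memory weight tensor
    max_mem_layer_modulation = 1e1       # new in 0.1.0: cap that scaling at 10x
)

seq = torch.randn(2, 1024, 384)
retrieved = mem(seq)
```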
@@ -250,7 +254,7 @@ class NeuralMemory(Module):
         self.merge_heads = Rearrange('b h n d -> b n (h d)')
         self.combine_heads = LinearNoBias(dim_inner, dim) if heads > 1 else nn.Identity()
 
-        self.retrieve_gate = nn.Sequential(
+        self.retrieve_gate = Sequential(
             LinearNoBias(dim, heads),
             Rearrange('b n h -> b h n 1'),
             nn.Sigmoid()
@@ -270,6 +274,8 @@ class NeuralMemory(Module):
 
         self.memory_model = model
 
+        self.num_memory_parameter_tensors = len(set(model.parameters()))
+
         # the chunk size within the paper where adaptive step, momentum, weight decay are shared
 
         self.chunk_size = chunk_size
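For context, `len(set(model.parameters()))` counts the distinct parameter tensors of whatever memory model is passed in. A hypothetical bias-free two-layer MLP, for example, yields 2:

```python
import torch.nn as nn

# hypothetical memory model, only to show what the new counter measures
model = nn.Sequential(
    nn.Linear(64, 64, bias = False),
    nn.SiLU(),
    nn.Linear(64, 64, bias = False)
)

num_memory_parameter_tensors = len(set(model.parameters()))
print(num_memory_parameter_tensors)  # 2 -> one weight tensor per linear layer
```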
@@ -299,15 +305,14 @@ class NeuralMemory(Module):
         nn.init.normal_(self.empty_memory_embed, std = 0.02)
 
         # learned adaptive learning rate and momentum
-        # todo - explore mlp layerwise learned lr / momentum
 
-        self.to_momentum = nn.Sequential(
+        self.to_momentum = Sequential(
             Reduce('b (n c) ... -> b n ...', 'mean', c = chunk_size),
             LinearNoBias(dim, heads),
             Rearrange('b n h -> (b h) n 1')
         )
 
-        self.to_adaptive_step = nn.Sequential(
+        self.to_adaptive_step = Sequential(
             LinearNoBias(dim, heads),
             Rearrange('b n h -> (b h) n')
         )
@@ -317,13 +322,24 @@ class NeuralMemory(Module):
 
         self.adaptive_step_transform = adaptive_step_transform
 
+        # per layer learning rate modulation
+
+        self.to_layer_modulation = Sequential(
+            Reduce('b (n c) ... -> b n ...', 'mean', c = chunk_size),
+            LinearNoBias(dim, heads * self.num_memory_parameter_tensors),
+            Rearrange('b n (h w) -> w (b h) n', h = heads),
+            nn.Sigmoid()
+        ) if per_parameter_lr_modulation else None
+
+        self.max_mem_layer_modulation = max_mem_layer_modulation
+
         # allow for softclamp the gradient norms for storing memories
 
         self.max_grad_norm = max_grad_norm
 
         # weight decay factor
 
-        self.to_decay_factor = nn.Sequential(
+        self.to_decay_factor = Sequential(
             Reduce('b (n c) ... -> b n ...', 'mean', c = chunk_size),
             LinearNoBias(dim, heads),
             Rearrange('b n h -> (b h) n 1')
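A standalone sketch of the shape the new modulation head produces, using plain `nn.Sequential` instead of the library's `Sequential` helper and illustrative sizes: one sigmoid gate per memory parameter tensor `w`, per flattened batch-head, per chunk.

```python
import torch
from torch import nn
from einops.layers.torch import Rearrange, Reduce

dim, heads, chunk_size, num_param_tensors = 384, 2, 64, 4  # illustrative values

to_layer_modulation = nn.Sequential(
    Reduce('b (n c) ... -> b n ...', 'mean', c = chunk_size),        # one token summary per chunk
    nn.Linear(dim, heads * num_param_tensors, bias = False),         # a gate per head per weight tensor
    Rearrange('b n (h w) -> w (b h) n', h = heads),
    nn.Sigmoid()
)

seq = torch.randn(2, 1024, dim)
print(to_layer_modulation(seq).shape)  # torch.Size([4, 4, 16]) -> (w, b * h, num chunks)
```

Later in the diff this sigmoid output is multiplied by `max_mem_layer_modulation`, so with the default of `1e1` each gate ends up in the range (0, 10).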
@@ -387,6 +403,11 @@ class NeuralMemory(Module):
         adaptive_momentum = self.to_momentum(seq).sigmoid()
         decay_factor = self.to_decay_factor(seq).sigmoid()
 
+        need_layer_lr_mod = exists(self.to_layer_modulation)
+
+        if need_layer_lr_mod:
+            layer_lr_mod = self.to_layer_modulation(seq) * self.max_mem_layer_modulation
+
         # keys and values
 
         keys, values = self.to_keys_values(seq).chunk(2, dim = -1)
@@ -418,6 +439,11 @@ class NeuralMemory(Module):
 
         grads = grads.apply(lambda t: rearrange(t, '(b n) ... -> b n ...', b = batch))
 
+        # maybe per layer modulation
+
+        if need_layer_lr_mod:
+            grads = TensorDict({name: einx.multiply('b h, b h ... -> b h ...', layer_lr_mod, t) for layer_lr_mod, (name, t) in zip(layer_lr_mod, grads.items())})
+
         # negative gradients, adaptive lr already applied as loss weight
 
         surprises = grads.apply(lambda t: -t)
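The `einx.multiply` pattern above broadcasts a 2-d gate over all trailing dimensions of each gradient tensor. A quick check with made-up shapes (axis names follow the pattern string, not the library's actual layout):

```python
import torch
import einx

gate = torch.rand(2, 8)             # (b, h) modulation values for one weight tensor
grad = torch.randn(2, 8, 64, 384)   # (b, h, ...) gradients with arbitrary trailing parameter dims

out = einx.multiply('b h, b h ... -> b h ...', gate, grad)
assert out.shape == grad.shape      # the gate is broadcast over the trailing dims
```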
@@ -469,7 +495,7 @@ class NeuralMemory(Module):
 
             # use associative scan again for learned forgetting (weight decay) - eq (13)
 
-            update = scan_fn(1. - decay_factor, momentum) # momentum is S / surprise in the paper
+            update = scan_fn(1. - decay_factor, momentum)
 
             updates[param_name] = inverse_pack(update)
             next_momentum[param_name] = inverse_pack(momentum)
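For reference, the scan on this line realizes the learned forgetting recurrence the surrounding comment cites as eq. (13) of the Titans paper, roughly M_t = (1 - α_t) · M_{t-1} + S_t, with `decay_factor` playing the role of α_t and `momentum` the surprise term S_t.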
@@ -580,7 +606,6 @@ class NeuralMemory(Module):
 
         past_weights, _ = past_state
 
-
         retrieved = self.retrieve_memories(seq, past_weights + updates)
 
         if not return_aux_kv_loss:

{titans_pytorch-0.0.65.dist-info → titans_pytorch-0.1.0.dist-info}/METADATA RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.0.65
+Version: 0.1.0
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -37,6 +37,7 @@ Requires-Python: >=3.9
 Requires-Dist: accelerated-scan>=0.2.0
 Requires-Dist: axial-positional-embedding>=0.3.5
 Requires-Dist: einops>=0.8.0
+Requires-Dist: einx>=0.3.0
 Requires-Dist: hyper-connections>=0.1.8
 Requires-Dist: ninja
 Requires-Dist: rotary-embedding-torch
@@ -58,6 +59,10 @@ Description-Content-Type: text/markdown
 
 Unofficial implementation of [Titans](https://arxiv.org/abs/2501.00663) in Pytorch. Will also contain some explorations into architectures beyond their simple 1-4 layer MLP for the neural memory module, if it works well to any degree.
 
+## Appreciation
+
+- [@sentialx](https://github.com/sentialx) for sharing his early experimental results with me
+
 ## Install
 
 ```bash

titans_pytorch-0.1.0.dist-info/RECORD ADDED
@@ -0,0 +1,8 @@
+titans_pytorch/__init__.py,sha256=I-4oF1JPEmUvhLszEocM5cDgHYNFcNp0Q9nbDYSPFqU,195
+titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
+titans_pytorch/mac_transformer.py,sha256=7PHBCbeB1LhHY5s3zAyYF0L3Mm7CNy4TOBbcpLX6bNE,17686
+titans_pytorch/titans.py,sha256=L3Mu6pDnimD4MNn_832trFEJgXOPjxSdTrB9jiSUSTk,18533
+titans_pytorch-0.1.0.dist-info/METADATA,sha256=LuWDzv-NbGxYKOMThN_WKQWDueyIsOAMSwwiE_BDraI,4595
+titans_pytorch-0.1.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+titans_pytorch-0.1.0.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+titans_pytorch-0.1.0.dist-info/RECORD,,

titans_pytorch-0.0.65.dist-info/RECORD DELETED
@@ -1,8 +0,0 @@
-titans_pytorch/__init__.py,sha256=I-4oF1JPEmUvhLszEocM5cDgHYNFcNp0Q9nbDYSPFqU,195
-titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
-titans_pytorch/mac_transformer.py,sha256=7PHBCbeB1LhHY5s3zAyYF0L3Mm7CNy4TOBbcpLX6bNE,17686
-titans_pytorch/titans.py,sha256=y6lJRErIoM6T2aTVFlf1GxSB0cpsmBZdSIj1DCHUCQ8,17486
-titans_pytorch-0.0.65.dist-info/METADATA,sha256=oDjEiufwOninsFDoCGbu691LXc1mey2OT7j6PNzkz0Q,4457
-titans_pytorch-0.0.65.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-titans_pytorch-0.0.65.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-titans_pytorch-0.0.65.dist-info/RECORD,,