titans-pytorch 0.1.12__py3-none-any.whl → 0.1.15__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
titans_pytorch/titans.py CHANGED
@@ -301,6 +301,45 @@ class MemoryAttention(Module):
 
         return out
 
+# associative scan wrapper
+
+class AssocScan(Module):
+    def __init__(
+        self,
+        use_accelerated = False
+    ):
+        super().__init__()
+        self.use_accelerated = use_accelerated
+
+    def forward(self, gates, inputs):
+
+        if not self.use_accelerated:
+            _, outputs = associative_scan(binary_operator, (gates, inputs))
+            return outputs
+
+        from accelerated_scan.triton import scan as triton_scan
+        from accelerated_scan.warp import scan as warp_scan
+
+        scan = triton_scan if gates.is_cuda else warp_scan
+
+        def accelerate_scan_fn(gates, inputs):
+            gates = gates.expand_as(inputs)
+            gates, inputs = tuple(rearrange(t, 'b n d -> b d n') for t in (gates, inputs))
+
+            seq_len = gates.shape[-1]
+            next_power_two_seq_len = 2 ** max(5, int(math.ceil(math.log2(seq_len))))
+
+            gates = F.pad(gates, (0, next_power_two_seq_len - seq_len))
+            inputs = F.pad(inputs, (0, next_power_two_seq_len - seq_len))
+
+            outputs = scan(gates.contiguous(), inputs.contiguous())
+
+            outputs = outputs[..., :seq_len]
+            outputs = rearrange(outputs, 'b d n -> b n d')
+            return outputs
+
+        return accelerate_scan_fn(gates, inputs)
+
 # main neural memory
 
 def default_adaptive_step_transform(adaptive_step, max_lr = 1e-2):
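For orientation: both branches of `AssocScan.forward` compute the same first-order gated recurrence, h_t = g_t * h_(t-1) + x_t, scanned along the sequence dimension. The power-of-two padding exists because the `accelerated-scan` kernels expect power-of-two sequence lengths (minimum 32, hence the `max(5, ...)`), and the zero padding is sliced off afterwards. A minimal sequential sketch, not code from this package, that either branch should agree with:

```python
import torch

def sequential_scan_reference(gates, inputs):
    # naive O(n) unroll of the scan: h_t = g_t * h_{t-1} + x_t, with h_{-1} = 0
    # gates: (b, n, 1) or (b, n, d), inputs: (b, n, d)
    hidden = torch.zeros_like(inputs[:, 0])
    outputs = []
    for t in range(inputs.shape[1]):
        hidden = gates[:, t] * hidden + inputs[:, t]
        outputs.append(hidden)
    return torch.stack(outputs, dim = 1)  # (b, n, d)
```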
@@ -323,6 +362,7 @@ class NeuralMemory(Module):
         per_parameter_lr_modulation = False, # allow outer network to control learning rate per weight matrix of memory network
         max_mem_layer_modulation = 1e1, # max of 10.
         attn_pool_chunks = False,
+        momentum = True,
         pre_rmsnorm = True,
         post_rmsnorm = True,
         learned_mem_model_weights = True,
@@ -338,6 +378,10 @@ class NeuralMemory(Module):
 
         self.retrieve_chunk_size, self.store_chunk_size = pair(chunk_size)
 
+        # associative scan
+
+        self.assoc_scan = AssocScan(use_accelerated = use_accelerated_scan)
+
         # norms
 
         self.retrieve_norm = nn.RMSNorm(dim) if pre_rmsnorm else nn.Identity()
@@ -423,7 +467,7 @@ class NeuralMemory(Module):
         self.to_momentum = Sequential(
             LinearNoBias(dim, heads),
             Rearrange('b n h -> (b h) n 1')
-        )
+        ) if momentum else None
 
         self.to_adaptive_step = Sequential(
             LinearNoBias(dim, heads),
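With `to_momentum` now `None` whenever `momentum = False`, the store path can skip the momentum scan entirely (see the `has_momentum` guard in the hunks below). A hedged usage sketch; `dim` and `chunk_size` are existing constructor arguments, and the remaining defaults are assumed:

```python
from titans_pytorch import NeuralMemory

# neural memory whose surprise updates skip the momentum scan of
# eq (10) and feed directly into the weight-decay scan of eq (13)
mem = NeuralMemory(
    dim = 384,
    chunk_size = 64,
    momentum = False  # new flag in this release
)
```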
@@ -462,12 +506,15 @@ class NeuralMemory(Module):
 
         self.register_buffer('zero', torch.tensor(0.), persistent = False)
 
-    def init_weights_and_momentum(self):
+    def init_weights_and_momentum(self, zero_weights = False):
         params = TensorDict(dict(self.memory_model.named_parameters()))
 
-        init_weights = params.clone().zero_()
+        init_weights = params
         init_momentum = params.clone().zero_()
 
+        if zero_weights:
+            init_weights = params.clone().zero_()
+
         return init_weights, init_momentum
 
     def init_empty_memory_embed(self, batch, seq_len):
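By default the initial weights are now the live `memory_model` parameters themselves, so they can be learned end to end, while `zero_weights = True` recovers the old zero-initialized behavior. A small standalone sketch of the `TensorDict` semantics involved (illustrative only, using the `tensordict` dependency):

```python
import torch
from torch import nn
from tensordict import TensorDict

model = nn.Linear(4, 4, bias = False)
params = TensorDict(dict(model.named_parameters()))

init_weights = params                   # zero_weights = False: the learned parameters
init_zeroed  = params.clone().zero_()   # zero_weights = True: independent zeroed copy

assert torch.equal(init_weights['weight'], model.weight)
assert init_zeroed['weight'].abs().sum() == 0
```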
@@ -497,14 +544,10 @@ class NeuralMemory(Module):
 
         seq = seq[:, :round_down_seq_len]
 
-        # curr weights + past weights, in the case that the initial weights are learned
-
-        curr_weights = TensorDict(dict(self.memory_model.named_parameters()))
+        # get the weights of the memory network
 
         past_state = tuple(TensorDict(d) for d in past_state)
-        past_weights, past_momentum = past_state
-
-        curr_weights = curr_weights + past_weights
+        curr_weights, past_momentum = past_state
 
         # derive learned hparams for optimization of memory network
 
@@ -513,10 +556,13 @@ class NeuralMemory(Module):
 
         chunked_seq = self.reduce_to_chunk_rep(seq, chunk_size = chunk_size)
 
-        adaptive_momentum = self.to_momentum(chunked_seq).sigmoid()
         decay_factor = self.to_decay_factor(chunked_seq).sigmoid()
 
         need_layer_lr_mod = exists(self.to_layer_modulation)
+        has_momentum = exists(self.to_momentum)
+
+        if has_momentum:
+            adaptive_momentum = self.to_momentum(chunked_seq).sigmoid()
 
         if need_layer_lr_mod:
             layer_lr_mod = self.to_layer_modulation(chunked_seq) * self.max_mem_layer_modulation
@@ -561,57 +607,31 @@ class NeuralMemory(Module):
 
         surprises = grads.apply(lambda t: -t)
 
-        # determine scan function
-
-        def default_associative_scan(gates, inputs):
-            _, outputs = associative_scan(binary_operator, (gates, inputs))
-            return outputs
-
-        if self.use_accelerated_scan:
-            from accelerated_scan.triton import scan as triton_scan
-            from accelerated_scan.warp import scan as warp_scan
-
-            scan = triton_scan if seq.is_cuda else warp_scan
-
-            def accelerate_scan_fn(gates, inputs):
-                gates = gates.expand_as(inputs)
-                gates, inputs = tuple(rearrange(t, 'b n d -> b d n') for t in (gates, inputs))
-
-                seq_len = gates.shape[-1]
-                next_power_two_seq_len = 2 ** max(5, int(math.ceil(math.log2(seq_len))))
-
-                gates = F.pad(gates, (0, next_power_two_seq_len - seq_len))
-                inputs = F.pad(inputs, (0, next_power_two_seq_len - seq_len))
-
-                outputs = scan(gates.contiguous(), inputs.contiguous())
-
-                outputs = outputs[..., :seq_len]
-                outputs = rearrange(outputs, 'b d n -> b n d')
-                return outputs
-
-            scan_fn = accelerate_scan_fn
-        else:
-            scan_fn = default_associative_scan
-
         # momentum + weight decay - momentum is the new contribution, as most linear RNNs have learned forgetting gates
 
-        next_momentum = TensorDict()
+        next_momentum = TensorDict() if has_momentum else None
         updates = TensorDict()
 
         for param_name, surprise in surprises.items():
 
             surprise, inverse_pack = pack_one_with_inverse(surprise, 'b n *')
 
+            update = surprise
+
             # derive momentum with associative scan - eq (10)
 
-            momentum = scan_fn(adaptive_momentum, surprise) # momentum is S / surprise in the paper
+            if has_momentum:
+                update = self.assoc_scan(adaptive_momentum, surprise) # momentum is S / surprise in the paper
+                momentum = update
 
             # use associative scan again for learned forgetting (weight decay) - eq (13)
 
-            update = scan_fn(1. - decay_factor, momentum)
+            update = self.assoc_scan(1. - decay_factor, update)
 
             updates[param_name] = inverse_pack(update)
-            next_momentum[param_name] = inverse_pack(momentum)
+
+            if has_momentum:
+                next_momentum[param_name] = inverse_pack(momentum)
 
         # compute the next weight per batch
 
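In scalar form, the two chained `assoc_scan` calls in this hunk implement the recurrences referenced in the comments: eq (10), S_t = eta_t * S_(t-1) + u_t, momentum over the surprise u_t (the negated gradient), followed by eq (13), dW_t = (1 - alpha_t) * dW_(t-1) + S_t, the learned forgetting gate. A sequential reference sketch (illustrative, shapes assumed) equivalent to chaining the two scans:

```python
import torch

def momentum_then_decay_reference(surprise, adaptive_momentum, decay_factor):
    # surprise: (b, n, d) negated grads; adaptive_momentum, decay_factor: (b, n, 1)
    momentum_state = torch.zeros_like(surprise[:, 0])
    update_state = torch.zeros_like(surprise[:, 0])
    updates = []
    for t in range(surprise.shape[1]):
        # eq (10): S_t = eta_t * S_{t-1} + u_t
        momentum_state = adaptive_momentum[:, t] * momentum_state + surprise[:, t]
        # eq (13): dW_t = (1 - alpha_t) * dW_{t-1} + S_t
        update_state = (1. - decay_factor[:, t]) * update_state + momentum_state
        updates.append(update_state)
    return torch.stack(updates, dim = 1)  # per-timestep weight updates
```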
{titans_pytorch-0.1.12.dist-info → titans_pytorch-0.1.15.dist-info}/METADATA RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.1.12
+Version: 0.1.15
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -35,7 +35,7 @@ Classifier: Programming Language :: Python :: 3.9
 Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
 Requires-Python: >=3.9
 Requires-Dist: accelerated-scan>=0.2.0
-Requires-Dist: axial-positional-embedding>=0.3.7
+Requires-Dist: axial-positional-embedding>=0.3.9
 Requires-Dist: einops>=0.8.0
 Requires-Dist: einx>=0.3.0
 Requires-Dist: hyper-connections>=0.1.8
titans_pytorch-0.1.15.dist-info/RECORD ADDED
@@ -0,0 +1,8 @@
+titans_pytorch/__init__.py,sha256=u0tta_KqhOdfzCEDWT9P4_jejJEK2q1XxhsEzB5MnQU,223
+titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
+titans_pytorch/mac_transformer.py,sha256=zxknstaI_Uz47Y8WvZ3S7geJ-TNdqKV5Rvj0Jlw8njs,19271
+titans_pytorch/titans.py,sha256=R0e25ly2uTHkHSZEb-9Eqb0DqtFq8wFBB8iH1T6bYVg,22240
+titans_pytorch-0.1.15.dist-info/METADATA,sha256=SnNsoK4obeOAWFPhQypYJfJWZ_abXKr7WCvLMqFdyg0,6340
+titans_pytorch-0.1.15.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+titans_pytorch-0.1.15.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+titans_pytorch-0.1.15.dist-info/RECORD,,
titans_pytorch-0.1.12.dist-info/RECORD DELETED
@@ -1,8 +0,0 @@
-titans_pytorch/__init__.py,sha256=u0tta_KqhOdfzCEDWT9P4_jejJEK2q1XxhsEzB5MnQU,223
-titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
-titans_pytorch/mac_transformer.py,sha256=zxknstaI_Uz47Y8WvZ3S7geJ-TNdqKV5Rvj0Jlw8njs,19271
-titans_pytorch/titans.py,sha256=eDTqAIDZjSLd34t8M-dCaqVf_s0wZ9jhVIOfXF7E9ts,21887
-titans_pytorch-0.1.12.dist-info/METADATA,sha256=dL8HpHt6V5gN8p8px7sc2IgJGqXthE7rULKIrRFCwF8,6340
-titans_pytorch-0.1.12.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-titans_pytorch-0.1.12.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-titans_pytorch-0.1.12.dist-info/RECORD,,